Skip to content

Commit 13edf47

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
use random manager than random in specDB (#9)
Summary: Pull Request resolved: #9 titled Reviewed By: hsharma35, manuelcandales Differential Revision: D75301999 fbshipit-source-id: 164c76c474c87a0c23b02be7fbf8e3f7d1453756
1 parent 07da759 commit 13edf47

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

facto/specdb/function.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import random
8-
97
import torch
8+
from facto.inputgen.utils.random_manager import random_manager as rm
109
from facto.inputgen.variable.type import ScalarDtype
1110
from facto.inputgen.variable.utils import nextdown, nextup
1211

@@ -166,11 +165,11 @@ def factorize(n, length):
166165
factor_list = []
167166
prod = 1
168167
for _ in range(length):
169-
x = random.choice(range(10))
168+
x = rm.get_random().choice(range(10))
170169
factor_list.append(x)
171170
prod *= x
172171
if prod != 0:
173-
i = random.choice(range(length))
172+
i = rm.get_random().choice(range(length))
174173
factor_list[i] = 0
175174
return {tuple(factor_list)}
176175

@@ -180,29 +179,29 @@ def factorize(n, length):
180179
factors = factorize_into_primes(n)
181180
factor_list = [1] * length
182181
for factor in factors:
183-
x = random.choice(range(length))
182+
x = rm.get_random().choice(range(length))
184183
factor_list[x] *= factor
185184
return {tuple(factor_list)}
186185

187186

188187
def valid_view_copy_size(tensor, length):
189188
n = tensor.numel()
190189
valids = factorize(n, length)
191-
factors = random.choice(list(factorize(n, length)))
190+
factors = rm.get_random().choice(list(factorize(n, length)))
192191
if length >= 1:
193192
if n > 0:
194-
x = random.choice(range(length))
193+
x = rm.get_random().choice(range(length))
195194
factor_list = list(factors)
196195
factor_list[x] = -1
197196
valids |= {tuple(factor_list)}
198197
else:
199198
zeros = [i for i in range(length) if factors[i] == 0]
200-
z = random.choice(zeros)
199+
z = rm.get_random().choice(zeros)
201200
factor_list = list(factors)
202201
factor_list[z] = -1
203202
for i in range(length):
204203
if i != z and factors[i] == 0:
205-
factor_list[i] = random.choice(range(1, 10))
204+
factor_list[i] = rm.get_random().choice(range(1, 10))
206205
valids |= {tuple(factor_list)}
207206
return valids
208207

@@ -221,39 +220,39 @@ def invalid_view_copy_size(tensor, length):
221220
if n > 2:
222221
invalids |= factorize(n - 1, length)
223222
if n > 3:
224-
x = random.choice(range(2, n - 1))
223+
x = rm.get_random().choice(range(2, n - 1))
225224
invalids |= factorize(x, length)
226225
invalids |= factorize(n + 1, length)
227226
if n > 0:
228227
invalids |= factorize(2 * n, length)
229228
invalids |= factorize(3 * n, length)
230-
factors = random.choice(list(factorize(n, length)))
229+
factors = rm.get_random().choice(list(factorize(n, length)))
231230
potential_negative = []
232231
for ix, factor in enumerate(factors):
233232
if factor > 1:
234233
potential_negative.append(ix)
235234
if len(potential_negative) >= 1:
236-
x = random.choice(potential_negative)
235+
x = rm.get_random().choice(potential_negative)
237236
factor_list = list(factors)
238237
factor_list[x] = -factors[x]
239238
invalids |= {tuple(factor_list)}
240239
if len(potential_negative) >= 2:
241-
x, y = random.sample(potential_negative, 2)
240+
x, y = rm.get_random().sample(potential_negative, 2)
242241
factor_list = list(factors)
243242
factor_list[x] = -factors[x]
244243
factor_list[y] = -factors[y]
245244
invalids |= {tuple(factor_list)}
246245
if length >= 2:
247-
x, y = random.sample(range(length), 2)
246+
x, y = rm.get_random().sample(range(length), 2)
248247
factor_list = list(factors)
249248
factor_list[x], factor_list[y] = -1, -1
250249
invalids |= {tuple(factor_list)}
251250
if length >= 1 and n == 0:
252251
zeros = [i for i in range(length) if factors[i] == 0]
253-
z = random.choice(zeros)
252+
z = rm.get_random().choice(zeros)
254253
non_z = [i for i in range(length) if i != z]
255254
if len(non_z) >= 1:
256-
x = random.choice([i for i in range(length) if i != z])
255+
x = rm.get_random().choice([i for i in range(length) if i != z])
257256
factor_list = list(factors)
258257
factor_list[x] = -1
259258
invalids |= {tuple(factor_list)}
@@ -289,9 +288,9 @@ def valid_dim_list_helper(tensor, pool, length):
289288

290289
n = max(tensor.dim(), 1)
291290

292-
sample = tuple(random.sample(pool, length))
291+
sample = tuple(rm.get_random().sample(pool, length))
293292
neg_sample = tuple(s - n for s in sample)
294-
mix_sample = tuple(random.choice([s, s - n]) for s in sample)
293+
mix_sample = tuple(rm.get_random().choice([s, s - n]) for s in sample)
295294

296295
return {sample, neg_sample, mix_sample}
297296

0 commit comments

Comments
 (0)