
Commit c450ded

Do not use global Numpy RNG (#461)
This patch switches from the global NumPy RNG to an explicitly constructed Generator instance, as recommended by NumPy and flagged by Ruff's NPY002 rule. Keeping RNG state out of global scope makes it easier to reason about, at the cost of slightly more verbose code. Fixes #460.
1 parent 079d4c7 commit c450ded
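
For context, a minimal sketch of the before/after pattern this commit applies (the values and variable names below are illustrative, not taken from the repository). The Generator is constructed once and then stored on, or passed into, whatever needs randomness, which is exactly what the two Python files below do.

import numpy as np

# Before (global state): every call mutates the hidden module-level RNG.
np.random.seed(42)
legacy_sample = np.random.choice(10, size=3, replace=False)

# After (explicit state): a Generator is constructed once and passed around.
rng = np.random.default_rng(42)  # the seed is optional; shown here for reproducibility
sample = rng.choice(10, size=3, replace=False)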

File tree

3 files changed: +10 -8 lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py
compiler_opt/tools/generate_vocab.py
pyproject.toml

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 3 additions & 1 deletion

@@ -367,6 +367,8 @@ def __init__(
     self._explore_on_features = explore_on_features
     logging.info('Reward key in exploration worker: %s', self._reward_key)

+    self._rng = np.random.default_rng()
+
   def compile_module(
       self,
       policy: Callable[[time_step.TimeStep | None], np.ndarray],
@@ -545,7 +547,7 @@ def explore_at_state_generator(
       distr_logits[replay_prefix[explore_step]] = -np.inf
       if all(-np.inf == logit for logit in distr_logits):
         break
-      replay_prefix[explore_step] = np.random.choice(
+      replay_prefix[explore_step] = self._rng.choice(
           range(distr_logits.shape[0]), p=scipy.special.softmax(distr_logits))
     base_policy = ExplorationWithPolicy(
         replay_prefix,
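
As a rough standalone sketch of the pattern used above (the class below is a hypothetical stand-in, not the actual exploration worker): one Generator is created in __init__ and later used to sample an action index from softmax-normalized logits, replacing the previous call to the global np.random.choice.

import numpy as np
import scipy.special

class _ExplorerSketch:
  """Hypothetical, stripped-down stand-in for the exploration worker."""

  def __init__(self):
    # One Generator per instance; nothing touches np.random's global state.
    self._rng = np.random.default_rng()

  def pick_action(self, distr_logits: np.ndarray) -> int:
    # Sample an index proportionally to softmax(logits), mirroring the diff above.
    probs = scipy.special.softmax(distr_logits)
    return int(self._rng.choice(range(distr_logits.shape[0]), p=probs))

explorer = _ExplorerSketch()
print(explorer.pick_action(np.array([0.1, 2.0, -np.inf])))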

compiler_opt/tools/generate_vocab.py

Lines changed: 6 additions & 6 deletions

@@ -120,12 +120,13 @@ def _parser_fn(serialized_proto):
   return _parser_fn


-def _generate_vocab(feature_values_arrays, feature_name):
+def _generate_vocab(feature_values_arrays, feature_name,
+                    rng: np.random.Generator):
   """Downsample and generate vocab using brute force method."""
   feature_values = np.concatenate(feature_values_arrays)
   sample_length = math.floor(
       np.shape(feature_values)[0] * FLAGS.sampling_fraction)
-  values = np.random.choice(feature_values, sample_length, replace=False)
+  values = rng.choice(feature_values, sample_length, replace=False)
   bin_edges = np.quantile(values, np.linspace(0, 1, FLAGS.num_buckets))
   filename = os.path.join(FLAGS.output_dir, f'{feature_name}.buckets')
   with open(filename, 'w', encoding='utf-8') as f:
@@ -168,14 +169,13 @@ def main(_) -> None:
   dataset = dataset.map(parser_fn, num_parallel_calls=tf.data.AUTOTUNE)
   data_list = np.array(list(dataset.as_numpy_iterator()), dtype=object)
   data_list = data_list.swapaxes(0, 1)
+  rng = np.random.default_rng()

   with mp.Pool(FLAGS.parallelism) as pool:
     feature_names = sorted(sequence_features)
     for i, feature_values_arrays in enumerate(data_list):
-      pool.apply_async(_generate_vocab, (
-          feature_values_arrays,
-          feature_names[i],
-      ))
+      pool.apply_async(_generate_vocab,
+                       (feature_values_arrays, feature_names[i], rng))
     pool.close()
     pool.join()
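
A simplified, self-contained version of the sampling step above (hypothetical names; the FLAGS-driven sampling fraction and bucket count are replaced with plain parameters) shows how the Generator is threaded through as an ordinary argument. Generator objects support pickling, which is why passing one in the apply_async argument tuple works across the process boundary.

import math
import numpy as np

def downsample_buckets(feature_values: np.ndarray, rng: np.random.Generator,
                       sampling_fraction: float = 0.5, num_buckets: int = 10):
  """Hypothetical helper mirroring _generate_vocab's sampling and bucketing."""
  sample_length = math.floor(np.shape(feature_values)[0] * sampling_fraction)
  # Downsample without replacement using the caller-supplied Generator.
  values = rng.choice(feature_values, sample_length, replace=False)
  # Evenly spaced quantiles of the sample become the bucket boundaries.
  return np.quantile(values, np.linspace(0, 1, num_buckets))

rng = np.random.default_rng()
print(downsample_buckets(np.arange(100.0), rng))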

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.ruff]
 line-length = 103
-lint.select = [ "C40", "C9", "E", "F", "PERF", "UP", "W", "YTT" ]
+lint.select = [ "C40", "C9", "E", "F", "PERF", "UP", "W", "YTT", "NPY", "PD" ]
 lint.ignore = [ "E722", "E731", "F401", "PERF203" ]
 lint.mccabe.max-complexity = 18
 target-version = "py310"
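
For a rough sense of what the newly enabled NPY rule family catches (the snippet below is illustrative, not from this repository): NPY002 reports legacy module-level np.random calls and points toward drawing from an explicit Generator instead.

import numpy as np

# Flagged by Ruff NPY002 (legacy global-state API):
np.random.seed(0)
noise_legacy = np.random.normal(size=4)

# Preferred form under NPY002: create a Generator and draw from it.
rng = np.random.default_rng(0)
noise = rng.normal(size=4)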
