Skip to content

Commit 8cb41b8

Browse files
ruff rules for comprehensions and performance (#420)
* ruff rules for comprehensions and performance * uv tool run yapf -irp . --------- Co-authored-by: Aiden Grossman <[email protected]>
1 parent 95a4050 commit 8cb41b8

22 files changed

+100
-101
lines changed

compiler_opt/benchmark/benchmark_chromium.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ def main(_):
206206
with open(test_description, encoding='UTF-8') as test_description_file:
207207
print(test_description)
208208
test_descriptions.append(json.load(test_description_file))
209-
test_executables = []
210-
for test_description in test_descriptions:
211-
test_executables.append(test_description['executable'])
209+
test_executables = [
210+
test_description['executable'] for test_description in test_descriptions
211+
]
212212

213213
if FLAGS.compile_llvm:
214214
benchmarking_utils.build_llvm(FLAGS.model_path, FLAGS.llvm_use_incremental,

compiler_opt/benchmark/benchmark_report_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,8 @@ def test_loading(self):
8181
'PerfCounter_1': [50],
8282
}
8383
})
84-
self.assertSetEqual(report.names(), set(['BM_A', 'BM_B']))
85-
self.assertSetEqual(report.counters(),
86-
set(['PerfCounter_0', 'PerfCounter_1']))
84+
self.assertSetEqual(report.names(), {'BM_A', 'BM_B'})
85+
self.assertSetEqual(report.counters(), {'PerfCounter_0', 'PerfCounter_1'})
8786
self.assertEqual(
8887
report.counter_means('BM_A', 'PerfCounter_0'),
8988
(10.488088481701517, 0.7071067811865476))

compiler_opt/benchmark/filter_tests.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ def main(_):
6666
test_suite_description = json.load(test_description_file)
6767
test_outputs = gtest_executable_utils.run_test_suite(
6868
test_suite_description, FLAGS.executable_path, [], FLAGS.num_threads)
69-
test_list = []
70-
for test_output in test_outputs:
71-
test_list.append(test_output['name'])
69+
test_list = [test_output['name'] for test_output in test_outputs]
7270
# copy the old test suite and just replace the tests array
7371
new_test_suite_description = test_suite_description
7472
new_test_suite_description['tests'] = test_list

compiler_opt/benchmark/gtest_executable_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,8 @@ def run_test_suite(test_suite_description: Dict[str, List[str]],
125125
if num_threads is None:
126126
num_threads = 1
127127

128-
test_descriptions = []
129-
for test in test_suite_description['tests']:
130-
test_descriptions.append((test_executable, test, perf_counters))
128+
test_descriptions = [(test_executable, test, perf_counters)
129+
for test in test_suite_description['tests']]
131130

132131
test_data_output = Parallel(n_jobs=num_threads)(
133132
delayed(run_and_parse)(test_description)

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _msg_pump(self):
184184
# clear out pending futures and mark ourselves as "stopped" by null-ing
185185
# the map
186186
with self._lock:
187-
for _, v in self._map.items():
187+
for v in self._map.values():
188188
v.set_exception(concurrent.futures.CancelledError())
189189
self._map = None
190190

compiler_opt/es/blackbox_learner.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,12 @@ def __init__(self,
168168

169169
def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
170170
"""Get perturbations for the model weights."""
171-
perturbations = []
172171
rng = np.random.default_rng(seed=self._seed)
173-
for _ in range(self._config.total_num_perturbations):
174-
perturbations.append(
175-
rng.normal(size=len(self._model_weights)) *
176-
self._config.precision_parameter)
177-
return perturbations
172+
return [
173+
rng.normal(size=len(self._model_weights)) *
174+
self._config.precision_parameter
175+
for _ in range(self._config.total_num_perturbations)
176+
]
178177

179178
def _update_model(self, perturbations: List[npt.NDArray[np.float32]],
180179
rewards: List[float]) -> None:
@@ -276,10 +275,10 @@ def run_step(self, pool: FixedWorkerPool) -> None:
276275
p for p in initial_perturbations for p in (p, -p)
277276
]
278277

279-
perturbations_as_policies = []
280-
for perturbation in initial_perturbations:
281-
perturbations_as_policies.append(
282-
self._get_policy_from_perturbation(perturbation))
278+
perturbations_as_policies = [
279+
self._get_policy_from_perturbation(perturbation)
280+
for perturbation in initial_perturbations
281+
]
283282

284283
results = self._evaluator.get_results(pool, perturbations_as_policies)
285284
rewards = self._evaluator.get_rewards(results)

compiler_opt/es/regalloc_trace/regalloc_trace_worker.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,12 @@ def _build_corpus(self, modules: Collection[corpus.ModuleSpec],
110110
else:
111111
tflite_policy_dir = None
112112

113-
compile_futures = []
114113
with concurrent.futures.ThreadPoolExecutor(
115114
max_workers=self._thread_count) as thread_pool:
116-
for module in modules:
117-
compile_futures.append(
118-
thread_pool.submit(self._compile_module, module, output_directory,
119-
tflite_policy_dir))
115+
compile_futures = [
116+
thread_pool.submit(self._compile_module, module, output_directory,
117+
tflite_policy_dir) for module in modules
118+
]
120119

121120
for future in compile_futures:
122121
if future.exception() is not None:

compiler_opt/rl/corpus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def __init__(self,
342342
'{context.module_full_path}') + additional_flags
343343

344344
# don't use add/remove for replace
345-
add_keys = set(k.split('=', maxsplit=1)[0] for k in additional_flags)
345+
add_keys = {k.split('=', maxsplit=1)[0] for k in additional_flags}
346346
if add_keys.intersection(
347347
set(replace_flags)) or set(delete_flags).intersection(
348348
set(replace_flags)) or add_keys.intersection(set(delete_flags)):

compiler_opt/rl/data_reader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def _parser_fn(serialized_proto):
4444
# and stored in the feature list.
4545
context_features = {}
4646
# pylint: disable=g-complex-comprehension
47-
sequence_features = dict(
48-
(tensor_spec.name,
49-
tf.io.FixedLenSequenceFeature(
50-
shape=tensor_spec.shape, dtype=tensor_spec.dtype))
51-
for tensor_spec in agent_cfg.time_step_spec.observation.values())
47+
sequence_features = {
48+
tensor_spec.name:
49+
tf.io.FixedLenSequenceFeature(
50+
shape=tensor_spec.shape, dtype=tensor_spec.dtype)
51+
for tensor_spec in agent_cfg.time_step_spec.observation.values()
52+
}
5253
sequence_features[
5354
agent_cfg.action_spec.name] = tf.io.FixedLenSequenceFeature(
5455
shape=agent_cfg.action_spec.shape,

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,7 @@ def explore_at_state_generator(
561561
yield base_seq, base_policy
562562

563563
def _build_replay_prefix_list(self, seq_ex):
564-
ret_list = []
565-
for int_list in seq_ex:
566-
ret_list.append(int_list.int64_list.value[0])
567-
return ret_list
564+
return [int_list.int64_list.value[0] for int_list in seq_ex]
568565

569566
def _create_timestep(self, curr_obs_dict: env.TimeStep):
570567
curr_obs = curr_obs_dict.obs

0 commit comments

Comments
 (0)