Skip to content

Commit c724f00

Browse files
cyyever authored and pytorchmergebot committed
[2/N] Use key in dict for existence checks (pytorch#167174)
This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`.

Pull Request resolved: pytorch#167174
Approved by: https://github.com/mlazos
1 parent a51208c commit c724f00

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+72
-83
lines changed

test/cpp/api/init_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def run(initializer):
6464

6565
def main():
6666
initializer_parameter_map = {}
67-
for initializer in INITIALIZERS.keys():
67+
for initializer in INITIALIZERS:
6868
sys.stderr.write(f"Evaluating {initializer} ...\n")
6969
initializer_parameter_map[initializer] = run(initializer)
7070

test/cpp/api/optim_baseline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def main():
130130
options = parser.parse_args()
131131

132132
optimizer_parameter_map = {}
133-
for optimizer in OPTIMIZERS.keys():
133+
for optimizer in OPTIMIZERS:
134134
sys.stderr.write(f"Evaluating {optimizer} ...\n")
135135
optimizer_parameter_map[optimizer] = run(
136136
optimizer, options.iterations, options.sample_every

test/distributed/checkpoint/test_hf_safetensor_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_quantized_checkpoint_loading(self) -> None:
208208

209209
# Create model.safetensors.index.json with weight mapping
210210
weight_map = {}
211-
for key in quantized_checkpoint.keys():
211+
for key in quantized_checkpoint:
212212
weight_map[key] = "model.safetensors"
213213

214214
index_data = {
@@ -245,7 +245,7 @@ def test_quantized_checkpoint_loading(self) -> None:
245245
sorted(original_tensors.keys()), sorted(state_dict_to_load.keys())
246246
)
247247

248-
for tensor_name in original_tensors.keys():
248+
for tensor_name in original_tensors:
249249
original = original_tensors[tensor_name]
250250
loaded = state_dict_to_load[tensor_name]
251251

test/distributed/fsdp/test_fsdp_mixed_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def _run_test_mixed_precision_e2e(
498498
for name, tensor in state_dict.items():
499499
# Parameters and buffers are checkpointed in their
500500
# original dtypes, which may be different.
501-
if name in named_buffers.keys():
501+
if name in named_buffers:
502502
self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE)
503503
else:
504504
self.assertEqual(

test/distributed/test_c10d_common.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,9 +1189,7 @@ def _test_sequence_num_incremented(self, process_group, ranks):
11891189
self.assertEqual(len(set(rank_to_seq_num.values())), 2)
11901190
self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
11911191
expected_same = {
1192-
rank_to_seq_num[i]
1193-
for i in rank_to_seq_num.keys()
1194-
if i not in [0, 2]
1192+
rank_to_seq_num[i] for i in rank_to_seq_num if i not in [0, 2]
11951193
}
11961194
self.assertEqual(len(expected_same), 1)
11971195
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
@@ -1558,7 +1556,7 @@ def test_debug_level(self):
15581556
}
15591557
invalid_debug_modes = ["foo", 0, 1, -1]
15601558

1561-
for mode in mapping.keys():
1559+
for mode in mapping:
15621560
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
15631561
dist.set_debug_level_from_env()
15641562
set_debug_mode = dist.get_debug_level()

test/distributed/test_local_tensor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,14 @@ def test_basic_arithmetic_operations(self):
128128
self.assertEqual(len(result_add._local_tensors), 2)
129129

130130
# Verify the operation was applied to each local tensor
131-
for rank in identical_local_tensors.keys():
131+
for rank in identical_local_tensors:
132132
expected = identical_local_tensors[rank] + identical_local_tensors[rank]
133133
self.assertEqual(result_add._local_tensors[rank], expected)
134134

135135
# Test multiplication
136136
result_mul = lt1 * 2.0
137137
self.assertIsInstance(result_mul, LocalTensor)
138-
for rank in identical_local_tensors.keys():
138+
for rank in identical_local_tensors:
139139
expected = identical_local_tensors[rank] * 2.0
140140
self.assertEqual(result_mul._local_tensors[rank], expected)
141141

@@ -163,7 +163,7 @@ def test_mixed_operations_with_regular_tensors(self):
163163
result = lt + regular_tensor
164164
self.assertIsInstance(result, LocalTensor)
165165

166-
for rank in identical_local_tensors.keys():
166+
for rank in identical_local_tensors:
167167
expected = identical_local_tensors[rank] + regular_tensor
168168
self.assertEqual(result._local_tensors[rank], expected)
169169

@@ -212,14 +212,14 @@ def test_collectives_within_local_tensor_mode(self):
212212
dist.all_reduce(lt_sum, group=fake_pg)
213213

214214
expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]])
215-
for rank in test_tensors.keys():
215+
for rank in test_tensors:
216216
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
217217

218218
# Test broadcast within mode
219219
lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
220220
dist.broadcast(lt_broadcast, src=0, group=fake_pg)
221221

222-
for rank in test_tensors.keys():
222+
for rank in test_tensors:
223223
self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0])
224224

225225
# Test that regular operations still work
@@ -293,21 +293,21 @@ def test_collective_reduction_operations(self):
293293
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
294294
dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg)
295295
expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors
296-
for rank in test_tensors.keys():
296+
for rank in test_tensors:
297297
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
298298

299299
# Test MAX reduction
300300
lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
301301
dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg)
302302
expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors
303-
for rank in test_tensors.keys():
303+
for rank in test_tensors:
304304
self.assertEqual(lt_max._local_tensors[rank], expected_max)
305305

306306
# Test MIN reduction
307307
lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
308308
dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg)
309309
expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors
310-
for rank in test_tensors.keys():
310+
for rank in test_tensors:
311311
self.assertEqual(lt_min._local_tensors[rank], expected_min)
312312

313313
def test_all_reduce_collective(self):
@@ -328,7 +328,7 @@ def test_all_reduce_collective(self):
328328

329329
# Verify all ranks have the sum of all tensors (after adding 1 to each)
330330
expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]])
331-
for rank in different_tensors.keys():
331+
for rank in different_tensors:
332332
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
333333

334334
def test_broadcast_collective(self):
@@ -348,7 +348,7 @@ def test_broadcast_collective(self):
348348

349349
# Verify all ranks have rank 1's original tensor
350350
expected_broadcast = different_tensors[1]
351-
for rank in different_tensors.keys():
351+
for rank in different_tensors:
352352
self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast)
353353

354354
def test_all_gather_collective(self):

test/dynamo/test_subclasses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4036,7 +4036,7 @@ def backend(gm, args):
40364036

40374037
@parametrize(
40384038
"nt_view_name",
4039-
[k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
4039+
[k for k in VIEW_TEST_CASES if k != "subclass_dense_subclass_dense"],
40404040
)
40414041
def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
40424042
self._input_view_test(nt_view_name)

test/functorch/xfail_suggester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def parse_namespace(base):
7373
"sparse_": "sparse",
7474
"special_": "special",
7575
}
76-
for heading in mappings.keys():
76+
for heading in mappings:
7777
if base.startswith(heading):
7878
return mappings[heading], base[len(heading) :]
7979
return None, base

test/inductor/test_compiled_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def build_opt_kwarg_db():
320320
continue
321321

322322
if has_tensor_lr:
323-
for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys():
323+
for scheduler_cls in LR_SCHEDULER_TO_KWARGS:
324324
name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}"
325325
compiled_opt_db.append(
326326
(

test/profiler/test_profiler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -916,8 +916,7 @@ def judge(expected_event_count, prof):
916916
)
917917
for key, count in expected_event_count.items():
918918
self.assertTrue(
919-
(key in actual_event_count.keys())
920-
and (count == actual_event_count[key])
919+
(key in actual_event_count) and (count == actual_event_count[key])
921920
)
922921

923922
with _profile(use_kineto=kineto_available()) as prof:
@@ -1406,10 +1405,7 @@ def test_profiler_fwd_bwd_link(self):
14061405
s_ts_2 = flow_s_to_ts[2]
14071406
f_ts_2 = flow_f_to_ts[2]
14081407
self.assertTrue(
1409-
all(
1410-
ts in ts_to_name.keys()
1411-
for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]
1412-
)
1408+
all(ts in ts_to_name for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2])
14131409
)
14141410
self.assertTrue(
14151411
ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits"

0 commit comments

Comments (0)