Commit 1dd6b76

Revert "[1/N] Remove unused loop variables (pytorch#166258)"
This reverts commit 76b2c37. Reverted pytorch#166258 on behalf of https://github.com/atalman because it breaks test/distributed/test_serialization.py::TestSerialization::test_weights_only. [GH job link](https://github.com/pytorch/pytorch/actions/runs/18894311802/job/53929321703) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/76b2c37045e52540ec51e967aa7b6436a6b9b174) ([comment](pytorch#166258 (comment)))
1 parent 284716a · commit 1dd6b76
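
For context, the reverted PR had mechanically dropped loop bindings that were never read (e.g. replacing `enumerate(seq)` with plain iteration over `seq`, or `.items()` with `.values()`); this commit restores the original loops. A minimal sketch of the two loop shapes involved, with illustrative names only:

    vals = ["a", "b", "c"]

    # Shape restored by this revert: the index is bound even though
    # the body never reads it.
    for i, val in enumerate(vals):
        print(val)

    # Shape the reverted PR had introduced: drop the unused binding.
    for val in vals:
        print(val)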

47 files changed: +103 −73 lines

torch/_dynamo/eval_frame.py

Lines changed: 1 addition & 1 deletion
@@ -1791,7 +1791,7 @@ def produce_matching(
     for i, val in enumerate(sources):
         dict_of_source_vals[id(val)] = i

-    for val in candidates:
+    for i, val in enumerate(candidates):
         if isinstance(val, tuple(common_constant_types)):
             matched_elements_positions.append(None)
         elif id(val) not in dict_of_source_vals:

torch/_dynamo/guards.py

Lines changed: 5 additions & 3 deletions
@@ -317,7 +317,7 @@ def visit_dict_manager(node: DictGuardManager) -> bool:
         is_diff_guard_node = (
             node.get_source() in self.diff_guard_sources or node.fail_count() > 0
         )
-        for _idx, (key_mgr, val_mgr) in sorted(
+        for idx, (key_mgr, val_mgr) in sorted(
             node.get_key_value_managers().items()
         ):
             is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)
@@ -440,15 +440,17 @@ def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]:
         is_subtree_tag_safe = True

         # Recurse to get the tag safe roots from subtree.
-        for _idx, (key_mgr, val_mgr) in sorted(
+        for idx, (key_mgr, val_mgr) in sorted(
             node.get_key_value_managers().items()
         ):
             if key_mgr is not None:
                 visit(key_mgr)
             if val_mgr is not None:
                 tag_safe_roots.extend(visit(val_mgr))

-        for key_mgr, val_mgr in node.get_key_value_managers().values():
+        for idx, (key_mgr, val_mgr) in sorted(
+            node.get_key_value_managers().items()
+        ):
             if key_mgr:
                 is_subtree_tag_safe &= key_mgr.is_tag_safe()

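Note that the final guards.py hunk is more than a rename: the pre-revert code iterated `node.get_key_value_managers().values()` in dict insertion order, while the restored loop iterates `sorted(node.get_key_value_managers().items())` in key order. A self-contained sketch of the difference, using hypothetical stand-in values for the managers:

    # get_key_value_managers() maps an integer slot to a (key_mgr, val_mgr)
    # pair; plain strings stand in for the manager objects here.
    managers = {1: ("k1", "v1"), 0: ("k0", "v0")}

    print(list(managers.values()))   # [('k1', 'v1'), ('k0', 'v0')] -- insertion order
    print(sorted(managers.items()))  # [(0, ('k0', 'v0')), (1, ('k1', 'v1'))] -- key order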
torch/_dynamo/variables/optimizer.py

Lines changed: 5 additions & 3 deletions
@@ -289,7 +289,9 @@ def mark_static(x):
         params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
         all_static = True
         non_static_grads = []
-        for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
+        for p_ind, (p, p_vt) in enumerate(
+            zip(group["params"], params_vt.unpack_var_sequence(tx))
+        ):
             param_source = p_vt.source
             self.tensor_to_source[p] = param_source
             grad_source = GradSource(
@@ -320,12 +322,12 @@ def mark_static(x):

         # We have to again iterate over the state dict to collect the
         # tensor_to_source dict. This is used for the finalizer.
-        for idx, value in enumerate(self.value.state.values()):
+        for idx, (p, value) in enumerate(self.value.state.items()):
             p_state_source = DictGetItemSource(
                 state_source, ConstDictKeySource(state_source, idx)
             )
             tx.output.guard_on_key_order.add(p_state_source)
-            for inner_idx, v in enumerate(value.values()):
+            for inner_idx, (k, v) in enumerate(value.items()):
                 if (
                     isinstance(v, torch.Tensor)
                     and v not in self.grad_to_source
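
In the optimizer.py hunks, by contrast, the restored `.items()` loops visit entries in the same order as the `.values()` versions; they only add the key bindings (`p`, `k`) back. A throwaway check, with hypothetical optimizer-state contents:

    state = {"param0": {"step": 0}, "param1": {"step": 1}}  # hypothetical contents

    # Same indices and values either way; items() additionally yields the key.
    assert list(enumerate(state.values())) == [
        (idx, value) for idx, (p, value) in enumerate(state.items())
    ]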

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -240,7 +240,7 @@ def inner(*flat_args):

     # Inspect the state of the input tensor functional wrapper to detect input mutation info
     # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
-    for arg, f_arg in zip(flat_args, flat_f_args):
+    for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
         # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
         # strides between the functionalized arg inner tensors and non-functionalized arg inner
         # tensors. This is a problem as the inner tensor stride change may not be reflected

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 1 addition & 1 deletion
@@ -2041,7 +2041,7 @@ def maybe_coerce(x):

     assert len(meta.attrs) == len(runtime_subclass_keys)
     leaves = []
-    for attr, attr_meta in meta.attrs.items():
+    for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
         elem = getattr(x, attr)
         new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
             elem, attr_meta

torch/_functorch/_aot_autograd/subclass_parametrization.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
         module, name, UnwrapTensorSubclass()
     )

-    for child in module.children():
+    for name, child in module.named_children():
         unwrap_tensor_subclass_parameters(child)

     return module
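
`nn.Module.children()` yields the child modules directly, while `named_children()` yields `(name, module)` pairs, so the restored loop is equivalent when `name` goes unused. A runnable sketch of the recursion pattern, on a toy module:

    import torch

    def walk(module: torch.nn.Module) -> None:
        # named_children() yields (name, module) pairs; the name is bound
        # here even though the recursion does not use it.
        for name, child in module.named_children():
            walk(child)

    walk(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))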

torch/_functorch/partitioners.py

Lines changed: 6 additions & 2 deletions
@@ -1481,7 +1481,9 @@ def get_sample_rng_state(device: Optional[torch.device]):
         )
     )

-    for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
+    for rng_count, (base_node, node_pair) in enumerate(
+        recomputable_rng_ops_map.items()
+    ):
         # Step 2 - Modify the fwd pass such that
         fw_node = node_pair["fwd"]
         bw_node = node_pair["bwd"]
@@ -2712,7 +2714,9 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
         subgraph = getattr(module, hop_node.args[0].target)
         if isinstance(subgraph, fx.GraphModule):
             new_rng_inputs = []
-            for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
+            for idx, placeholder_node in enumerate(
+                subgraph.graph.find_nodes(op="placeholder")
+            ):
                 if rng_string in placeholder_node.name:
                     # Found a rng state placeholder in the hop graph, lets add
                     # the corresponding node in the outer graph

torch/_functorch/pyfunctorch.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack):
             pushed.append(s)
         yield
     finally:
-        for _ in reversed(pushed):
+        for s in reversed(pushed):
             # TODO: would be nice to assert that the layers are the same, but
             # Python object identity is not preserved
             pop_dynamic_layer_stack()

torch/_higher_order_ops/utils.py

Lines changed: 2 additions & 2 deletions
@@ -907,7 +907,7 @@ def diff_tensor_meta(
         try:
             if val1 != val2:
                 pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
-        except GuardOnDataDependentSymNode:
+        except GuardOnDataDependentSymNode as _:
             pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
             continue
     return pair_diffs
@@ -1197,7 +1197,7 @@ def wrapped_fn(*flat_args):

     # call_op preserves ordering of proxies via schema
     materialized_args = []
-    for i, proxy in enumerate(arg_proxies):
+    for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
         if (
             isinstance(proxy, torch.fx.Node)
             and proxy.op == "get_attr"
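
Two generic Python details sit behind these hunks: `except E as _:` catches exactly what `except E:` catches and merely names the instance `_`, and `zip` stops at its shortest input, so the restored pairing of `arg_proxies` with `schema.arguments` silently truncates if the two ever differ in length. A standalone illustration with made-up sequences:

    proxies = ["p0", "p1", "p2"]
    arguments = ["a0", "a1"]

    # zip truncates to the shorter input: index 2 is never visited.
    print(list(enumerate(zip(proxies, arguments))))
    # [(0, ('p0', 'a0')), (1, ('p1', 'a1'))]

    try:
        raise ValueError("boom")
    except ValueError as _:
        pass  # identical handling to a bare `except ValueError:`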

torch/_higher_order_ops/while_loop.py

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ def _validate_cond_output(pred):

     if stack_output:
         outs: list[torch.Tensor] = []
-        for out in outputs:
+        for i, out in enumerate(outputs):
             outs.append(torch.stack(out, dim=0))
         return tuple(outs)
