Skip to content

Commit b074cba

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo] allow resume functions to have name in both freevars and varnames (pytorch#161544)
fixes pytorch#161542 Differential Revision: [D81073109](https://our.internmc.facebook.com/intern/diff/D81073109) Pull Request resolved: pytorch#161544 Approved by: https://github.com/StrongerXi, https://github.com/anijain2305
1 parent 80bf883 commit b074cba

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

test/dynamo/test_repros.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7129,6 +7129,24 @@ def fn(x):
71297129
0, sys.monitoring.events.PY_START, old_callback
71307130
)
71317131

7132+
def test_312_local_cell_overlap(self):
7133+
keys = range(10)
7134+
allowed = [0, 1, 2, 3]
7135+
7136+
def fn(x):
7137+
x = x + 1
7138+
torch._dynamo.graph_break()
7139+
key = [key for key in keys if key in allowed]
7140+
7141+
def inner():
7142+
nonlocal key
7143+
7144+
return x + key[0]
7145+
7146+
self.assertEqual(
7147+
fn(torch.ones(3)), torch.compile(fn, backend="eager")(torch.ones(3))
7148+
)
7149+
71327150
def test_unbind_copy_out(self):
71337151
def f(eye, out):
71347152
torch.unbind_copy(eye, out=out)

torch/_dynamo/resume_execution.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,7 @@ def update(
369369
code_options["co_varnames"] = tuple(
370370
args
371371
+ [v for v in argnames_null if v not in args]
372-
+ [
373-
v
374-
for v in code_options["co_varnames"]
375-
if v not in args and v not in freevars
376-
]
372+
+ [v for v in code_options["co_varnames"] if v not in args]
377373
+ [IS_TRACING_RESUME_PROLOGUE_VARNAME]
378374
)
379375
code_options["co_flags"] = code_options["co_flags"] & ~(

0 commit comments

Comments
 (0)