Skip to content

Commit a159920

Browse files
wdziurdzjansel
andauthored
Return ConstantVariable(None) from WithExitFunctionVariable.exit to prevent NoneType crash inside autocast exception path (pytorch#153612)
Return ConstantVariable(None) from WithExitFunctionVariable.exit to prevent NoneType crash inside autocast exception path (pytorch#152503) Copy of pytorch#152013 with PR time benchmarks updated (regressions seem unrelated) Pull Request resolved: pytorch#152503 Approved by: https://github.com/anijain2305, https://github.com/Skylion007 Co-authored-by: Jason Ansel <[email protected]>
1 parent 6f2f41c commit a159920

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/dynamo/test_exceptions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,28 @@ def fn(x):
128128
res = opt_fn(x)
129129
self.assertEqual(ref, res)
130130

131+
def test_autocast_with_exception(self):
132+
class Optimizer(torch.autograd.Function):
133+
@staticmethod
134+
def forward(ctx, x):
135+
raise NotImplementedError("Not implemented")
136+
137+
@staticmethod
138+
def backward(ctx, grad_out):
139+
return grad_out
140+
141+
@torch.compile
142+
def f(x: torch.Tensor):
143+
try:
144+
with torch.autocast(device_type="cpu", dtype=None):
145+
Optimizer.apply(x)
146+
except NotImplementedError:
147+
return x + 1
148+
149+
inp = torch.ones(3)
150+
out = f(inp)
151+
self.assertTrue(torch.equal(out, inp + 1))
152+
131153
@make_dynamo_test
132154
def test_propagate_exception_inside_ctx_manager(self):
133155
@contextlib.contextmanager

torch/_dynamo/variables/ctx_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def exit(self, tx: "InstructionTranslator", *args):
839839
tx.output.create_node(
840840
"call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
841841
)
842+
return variables.ConstantVariable.create(None)
842843

843844
def enter(self, tx):
844845
ctx = torch.amp._enter_autocast(*self.target_values)

0 commit comments

Comments
 (0)