
Commit 12553c5

Disable autocast (#794)
* Disable autocast
* Add global flag
* Add a test
1 parent 347334c commit 12553c5

2 files changed: 25 additions, 1 deletion

functorch/_src/aot_autograd.py

Lines changed: 8 additions & 0 deletions
@@ -160,6 +160,9 @@ class CompiledFunction(torch.autograd.Function):
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
+            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
             if compiled_fw is None:
                 with preserve_rng_state():
                     # Set input tensors that require grad to leaves
@@ -194,15 +197,20 @@ def forward(ctx, *flat_tensor_args):
                     compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])

         @staticmethod
         @disable_torchdynamo
         def backward(ctx, *flat_args):
+            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
             contiguous_args = [t.contiguous() for t in flat_args]
             # contiguous_args = [t for t in flat_args]
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
             return tuple(out)

     return CompiledFunction
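The change follows a save/disable/restore pattern: torch._C._jit_set_autocast_mode(False) returns the previous flag value, which is restored once the compiled forward or backward has run. A minimal sketch of the same pattern factored into a context manager is shown below; the helper name jit_autocast_disabled is hypothetical and not part of this commit.

import contextlib

import torch


@contextlib.contextmanager
def jit_autocast_disabled():
    # Hypothetical helper, not part of this commit.
    # _jit_set_autocast_mode returns the previous flag value, so it can be
    # restored on exit even if the wrapped call raises.
    old_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_flag)

Wrapped around the compiled_fw and compiled_bw calls, this would behave like the explicit flag handling above, with the added guarantee that the flag is restored on exceptions.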

test/test_pythonkey.py

Lines changed: 17 additions & 1 deletion
@@ -21,7 +21,7 @@
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions
+    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion,
 )

 from torch.testing._internal.common_device_type import ops
@@ -564,6 +564,22 @@ def fn(x):
         assert torch.allclose(ref, res)


+class TestAutocast(TestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_autocast(self):
+        mod = torchvision.models.resnet18().cuda()
+        mod.train()
+
+        x = torch.randn(16, 3, 32, 32, device="cuda")
+        aot_mod = memory_efficient_fusion(mod)
+
+        # Ensure that AOT Autograd works with AMP
+        with torch.cuda.amp.autocast(True):
+            res = aot_mod(x)
+            res.sum().backward()
+
+
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,
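For reference, a standalone sketch of the pattern the new test exercises: wrapping a module with memory_efficient_fusion and running forward and backward under CUDA autocast. The toy model below is a hypothetical stand-in for torchvision's resnet18 and is not part of the test suite.

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

# Hypothetical toy model standing in for torchvision.models.resnet18().
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).cuda()
model.train()

x = torch.randn(16, 32, device="cuda")
aot_model = memory_efficient_fusion(model)

# Forward and backward under AMP; with this commit the graph produced by
# AOT Autograd should no longer be re-autocast by the JIT.
with torch.cuda.amp.autocast(True):
    out = aot_model(x)
    out.sum().backward()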
