This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 7a7545e

Disable torchdynamo on AOT Autograd generated graphs (#662)

1 parent: 2eb181f

1 file changed: +8 -0 lines changed

functorch/_src/aot_autograd.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,12 @@
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 
+try:
+    from torchdynamo import disable as disable_torchdynamo
+except ImportError:
+    def disable_torchdynamo(x):
+        return x
+
 pytree._register_pytree_node(
     immutable_collections.immutable_list,
     lambda x: (list(x), None),
@@ -129,6 +135,7 @@ def create_aot_autograd_function(
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
+        @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
             if compiled_fw is None:
@@ -163,6 +170,7 @@ def forward(ctx, *flat_tensor_args):
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
+        @disable_torchdynamo
         def backward(ctx, *flat_args):
             contiguous_args = [t.contiguous() for t in flat_args]
             # contiguous_args = [t for t in flat_args]
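For illustration, a minimal self-contained sketch of the pattern this commit introduces: import torchdynamo's disable decorator when the package is available, and fall back to a no-op identity decorator otherwise, so functorch keeps working when torchdynamo is not installed. The decorated function my_fn below is a hypothetical stand-in for the forward/backward static methods in the diff.

# Sketch of the optional-dependency fallback used above; my_fn is hypothetical.
try:
    from torchdynamo import disable as disable_torchdynamo
except ImportError:
    # torchdynamo is absent: make the decorator a no-op that returns
    # the function unchanged.
    def disable_torchdynamo(x):
        return x

@disable_torchdynamo
def my_fn(a, b):
    # With torchdynamo installed, this body is excluded from its tracing;
    # without it, the decorator has no effect.
    return a + b

print(my_fn(1, 2))  # prints 3 in either case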

0 commit comments