Skip to content

Commit 8ff9485

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Update unflattening dynamo.disable (pytorch#161306)
Summary: Doing inline disabling causes recompiles with the reason "Cache line invalidated because L['___stack0'] got deallocated" Test Plan: CI Rollback Plan: Differential Revision: D80816956 Pull Request resolved: pytorch#161306 Approved by: https://github.com/pianpwk
1 parent b074cba commit 8ff9485

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

torch/export/unflatten.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -650,10 +650,7 @@ def process_forward_inputs(self, *args, **kwargs):
650650
return flat_args
651651

652652
def forward(self, *args, **kwargs):
653-
flat_args = torch._dynamo.disable(
654-
self.process_forward_inputs,
655-
reason="do not trace into preprocessing the inputs",
656-
)(*args, **kwargs)
653+
flat_args = self.process_forward_inputs(*args, **kwargs)
657654
signature = self.module_call_graph[0].signature
658655

659656
if is_fx_symbolic_tracing():
@@ -775,7 +772,17 @@ def unflatten(
775772
hierarchy as the original eager module pre-export.
776773
"""
777774
module = _remove_effect_tokens(module)
778-
return UnflattenedModule(module, flat_args_adapter)
775+
m = UnflattenedModule(module, flat_args_adapter)
776+
777+
# Disable process_forward_inputs as the adapter has many
778+
# non-dynamo-traceable behavior.
779+
m.process_forward_inputs = torch._dynamo.disable( # type: ignore[method-assign]
780+
m.process_forward_inputs,
781+
reason="do not trace into preprocessing the inputs",
782+
recursive=True,
783+
)
784+
785+
return m
779786

780787

781788
def _inplace_buffer_and_input_mutations(

0 commit comments

Comments
 (0)