
Conversation

shino16
Collaborator

@shino16 shino16 commented Oct 3, 2025

Fixes #2539 and other issues it caused.

What this fixes

torch.compile skips the lowering of the GraphModule that Thunder's splitter emits. This PR passes the GraphModule to torch._inductor.compile instead, which in turn calls torch._inductor.compile_fx.compile_fx, to make sure Inductor lowers it.

Although #2527 (comment) found that wrapping the GraphModule in a new Module is sufficient, #2551 (comment) found that Dynamo cannot trace the enter/exit of a torch.autocast() region.

Why this is more than just replacing the torch.compile call

(A) Graph Splitting

Inductor assumes that the output node of a GraphModule has the form return (t0, ..., tN). Usually torch.fx.passes.split_module.split_module (the underlying algorithm of the graph splitter) makes sure this holds for the submodules it creates, but when the splitter cuts out a node that returns a tuple, the submodule's output looks like return tuple_value instead. This breaks Inductor.

Example
import torch, thunder

def fn(x): return x.sinc()

@thunder.dynamo.thunderfx
def checkpointed_fn(x):
    return torch.utils.checkpoint.checkpoint(fn, x)

checkpointed_fn(torch.randn(2, 3))
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[2, 3]"):
        # No stacktrace found for following nodes
        submod_0 = self.submod_0(l_x_);  l_x_ = None
        submod_1 = self.submod_1(submod_0);  submod_0 = None
        return (submod_1,)
        
    class submod_0(torch.nn.Module):
        def forward(self, l_x_: "f32[2, 3]"):
             # File: /opt/pytorch/lightning-thunder/tmp/main.py:8 in checkpointed_fn, code: return torch.utils.checkpoint.checkpoint(fn, x)
            wrap_body_0 = self.wrap_body_0
            tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
            return tag_activation_checkpoint
            
        class wrap_body_0(torch.nn.Module):
            def forward(self, l_x_: "f32[2, 3]"):
                 # File: /opt/pytorch/lightning-thunder/tmp/main.py:5 in fn, code: def fn(x): return x.sinc()
                sinc: "f32[2, 3]" = l_x_.sinc();  l_x_ = None
                return (sinc,)
                
    class submod_1(torch.nn.Module):
        def forward(self, tag_activation_checkpoint):
             # File: /opt/pytorch/lightning-thunder/tmp/main.py:8 in checkpointed_fn, code: return torch.utils.checkpoint.checkpoint(fn, x)
            getitem: "f32[2, 3]" = tag_activation_checkpoint[0];  tag_activation_checkpoint = None
            return getitem
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 555, in call_and_expect_output_descs
    assert out_spec == out_desc_spec, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0x7c1a8fe84410>' raised:
AssertionError: ([<function aot_stage1_graph_capture.<locals>.orig_flat_fn2 at 0x7c1a5434ede0>, <function create_functional_call.<locals>.functional_call at 0x7c1a5423ab60>, GraphModule(
  (wrap_body_0): GraphModule()
)], [(FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=(2, 3)))),)], [PlainAOTOutput(idx=0)], TreeSpec(immutable_list, None, [TreeSpec(tuple, None, [*])]), TreeSpec(immutable_list, None, [*]))

This assertion fails on return tag_activation_checkpoint, which compile_fx.make_graph_return_tuple converts to return [tag_activation_checkpoint].

We work around this by grouping the subsequent getitem node into the same partition as the node that was cut out. This hack is applied to nodes whose example_value is a tuple, which includes torch.ops.higher_order.tag_activation_checkpoint and torch.ops.higher_order.autograd_function_apply.
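The grouping hack can be sketched as a small partition-assignment pass. This is an illustrative plain-Python sketch, not the PR's actual code: the real pass operates on torch.fx.Node objects and inspects node.meta["example_value"], and the Node class and function names here are hypothetical.

```python
# Hypothetical sketch: keep a getitem in the same partition as the
# tuple-returning node it indexes, so the cut-out submodule's output
# stays of the form `return (t0, ..., tN)` rather than `return tuple_value`.

class Node:
    def __init__(self, name, returns_tuple=False, producer=None, is_getitem=False):
        self.name = name
        self.returns_tuple = returns_tuple  # stand-in for example_value being a tuple
        self.producer = producer            # the node this getitem indexes into
        self.is_getitem = is_getitem

def group_getitems(nodes, partition_of):
    """Reassign each getitem to the partition of its tuple-producing node."""
    for node in nodes:
        if node.is_getitem and node.producer is not None and node.producer.returns_tuple:
            partition_of[node.name] = partition_of[node.producer.name]
    return partition_of

# Example: a checkpoint-like node that returns a tuple, followed by a getitem
# that the naive split would have placed in a different partition.
ckpt = Node("tag_activation_checkpoint", returns_tuple=True)
get = Node("getitem", producer=ckpt, is_getitem=True)
parts = group_getitems([ckpt, get], {"tag_activation_checkpoint": 0, "getitem": 1})
print(parts["getitem"])  # -> 0, grouped with its producer
```

With the getitem grouped into partition 0, the submodule returns the unpacked tensor rather than the raw tuple, which is the shape Inductor expects.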

(B) Graph break inside Inductor's fallback submodule

Unlike Dynamo, Inductor is not responsible for breaking the graph, so it raises an exception when it encounters something it cannot compile.

Example
import torch, thunder.dynamo

@thunder.dynamo.thunderfx
def fn(x, idx, val):
    x = x.clone()
    x[idx] = val
    return x

x = torch.randn(3, requires_grad=True)
idx = x > 0.5
val = torch.randn(torch.count_nonzero(idx), requires_grad=True)

fn(x, idx, val)
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3]", l_idx_: "b8[3]", l_val_: "f32[2]"):
        # No stacktrace found for following nodes
        submod_0 = self.submod_0(l_x_);  l_x_ = None
        submod_1 = self.submod_1(submod_0, l_idx_, l_val_);  l_idx_ = l_val_ = submod_1 = None
        return (submod_0,)
        
    class submod_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3]"):
             # File: /opt/pytorch/lightning-thunder/tmp/main.py:7 in fn, code: x = x.clone()
            x: "f32[3]" = l_x_.clone();  l_x_ = None
            return x
            
    class submod_1(torch.nn.Module):
        def forward(self, x: "f32[3]", l_idx_: "b8[3]", l_val_: "f32[2]"):
             # File: /opt/pytorch/lightning-thunder/tmp/main.py:8 in fn, code: x[idx] = val
            x[l_idx_] = l_val_;  setitem = x;  x = l_idx_ = l_val_ = setitem = None
            return ()         
torch._subclasses.fake_tensor.DynamicOutputShapeException: aten.nonzero.default

We must catch the exception and let Dynamo handle it by falling further back to torch.compile. I expect there are many more exceptions that torch._inductor.compile may raise. There is no common supertype for such exceptions, so we will have to cover each one as it pops up.
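The fallback logic can be sketched like this. All names here are hypothetical stand-ins (the real code catches actual torch exception types such as torch._subclasses.fake_tensor.DynamicOutputShapeException); the point is the enumerated-exceptions pattern, since the failures share no common base class.

```python
# Hypothetical sketch of "try Inductor, fall back to torch.compile".
# The exception classes below stand in for the real torch types; extend
# FALLBACK_EXCEPTIONS as new failure modes pop up.

class DynamicOutputShapeError(Exception):
    """Stand-in for torch._subclasses.fake_tensor.DynamicOutputShapeException."""

class InductorError(Exception):
    """Stand-in for torch._inductor.exc.InductorError."""

FALLBACK_EXCEPTIONS = (DynamicOutputShapeError, InductorError)

def compile_submodule(gm, inductor_compile, plain_compile):
    """Try Inductor first; on a known failure, let Dynamo handle the
    graph break by falling back to the plain torch.compile path."""
    try:
        return inductor_compile(gm)
    except FALLBACK_EXCEPTIONS:
        return plain_compile(gm)

# Example: an "inductor" that rejects data-dependent output shapes,
# like the aten.nonzero.default case above.
def fake_inductor(gm):
    raise DynamicOutputShapeError("aten.nonzero.default")

compiled = compile_submodule("gm", fake_inductor, lambda gm: f"fallback({gm})")
print(compiled)  # -> fallback(gm)
```

An unlisted exception type would propagate instead of triggering the fallback, which is why each new Inductor failure mode has to be added to the tuple explicitly.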

(C) Lack of example_value

Inductor requires FakeTensor instances representing the inputs. These are sometimes unavailable: see #2429. We simply fall back to torch.compile in such cases.

Concern

This PR does not cover cases where a submodule containing a torch.autocast() region falls back to torch.compile (when (B) or (C) applies). An AssertionError is raised in such cases (see the new test, which is xfailed).

Should we catch such errors and resort to eager mode? I wonder how big the impact will be.

@shino16
Collaborator Author

shino16 commented Oct 4, 2025

Note on (C) Lack of example_value: Thanks to pytorch/pytorch#163807, example_value no longer becomes unavailable in our reproducer #2429, unless we set torch._dynamo.config.capture_scalar_outputs to True. I am not aware of other situations that make example_value unavailable, but I will keep the mechanism and the tests for (C) for safety.

@shino16 shino16 marked this pull request as ready for review October 4, 2025 18:49
@shino16
Collaborator Author

shino16 commented Oct 6, 2025

Tests are failing on Windows with:

compiler = 'cl'

    @functools.cache
    def check_compiler_exist_windows(compiler: str) -> None:
        """
        Check if compiler is ready, in case end user not activate MSVC environment.
        """
        try:
            subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT)
        except FileNotFoundError as exc:
>           raise RuntimeError(f"Compiler: {compiler} is not found.") from exc
E           torch._inductor.exc.InductorError: RuntimeError: Compiler: cl is not found.
E           
E           Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\site-packages\torch\_inductor\cpp_builder.py:140: InductorError

I'm trying to prepare a Windows environment for investigation.

@kiya00
Collaborator

kiya00 commented Oct 6, 2025

Hi @shino16, IIRC all the torch.compile-related tests are skipped on Windows by the IS_WINDOWS flag (#1326); I think the situation remains unchanged.
I'm curious whether this PR could affect performance?

@mattteochen
Collaborator

Thank you, @shino16. It seems that this should solve the torch.compile GraphModule fallback issues that we were seeing. In theory, this shouldn't introduce a perf regression, since before we were running submodules in eager mode, but some benchmarks could be useful.

@shino16
Collaborator Author

shino16 commented Oct 6, 2025

Thank you for your comments! As @mattteochen says, we should gain a performance benefit from Inductor for fallback submodules, but compilation latency will increase for the same reason.

@kshitij12345
Collaborator

but latency will increase for the same reason.

IIUC, this would happen only for the first run, right?

@mattteochen
Collaborator

but latency will increase for the same reason.

IIUC, this would happen only for the first run, right?

Yes, torch inductor has its own cache (a fairly random snapshot):
(screenshot: proxy_gm)

Successfully merging this pull request may close these issues.

ThunderFX's fallback is not using Inductor compilation