
Conversation

@shino16 (Collaborator) commented Sep 29, 2025

Fixes #2539, which in turn fixes #2527 and #2501. This makes PR #2538 obsolete.

We use compile_fx to compile the fallback GraphModule. Unlike bare torch.compile, compile_fx properly lowers the GraphModule, as pointed out in #2527 (comment).
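
For illustration, a minimal sketch of that approach (not the PR's actual code; compile_fallback is a made-up name):

    import torch
    from torch._inductor.compile_fx import compile_fx

    def compile_fallback(gm: torch.fx.GraphModule, example_inputs):
        # compile_fx is Inductor's FX entry point: it lowers the already-captured
        # GraphModule through AOTAutograd and Inductor directly, whereas bare
        # torch.compile(gm) would re-trace the module with Dynamo first.
        return compile_fx(gm, example_inputs)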

@shino16 closed this Sep 29, 2025
@shino16 deleted the inductor-submodule-entrypoint branch on September 29, 2025 at 14:53
@shino16 (Collaborator, Author) commented Sep 29, 2025

Sorry, I found test failures that can't be fixed with compile_fx. I don't yet know why this happens, but we would need to somehow wrap compile_fx's return value, and wrapping the input for Inductor instead seems like a simpler approach.

____________________________________________________________ test_checkpoint_memory_use[sinc] _____________________________________________________________

op = <built-in method sinc of type object at 0x7d40bc97fbc0>

    @requiresCUDA
    @pytest.mark.parametrize("op", [torch.sin, torch.sinc])
    def test_checkpoint_memory_use(op):
        import torch.utils.checkpoint as checkpoint
    
        def fn(x):
            return op(op(op(op(x))))
    
        def checkpoint_fn(x):
            return checkpoint.checkpoint(fn, x, use_reentrant=False)
    
        initial_mem = torch.cuda.memory_allocated()
    
        x = torch.randn((128, 128), device="cuda", requires_grad=True)
        jfn = thunderfx(checkpoint_fn)
>       y = jfn(x)

thunder/tests/test_dynamo.py:880: 

(snip)

fn = <function aot_stage1_graph_capture.<locals>.orig_flat_fn2 at 0x7d3ece392c00>
args = (FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(128, 128), requires_grad=True),
       device='cuda:0', grad_fn=<Error>)),)

    def call_and_expect_output_descs(fn, args):
        outs_pair = fn(*args)
        assert isinstance(outs_pair, tuple) and len(outs_pair) == 2, (fn, outs_pair)
        outs, outs_descs = outs_pair
        # The Tensor tests protects against the test when there are no outputs
        out_vals, out_spec = pytree.tree_flatten(outs)
        out_desc_vals, out_desc_spec = pytree.tree_flatten(outs_descs)
>       assert out_spec == out_desc_spec, (
            fn_wrappers(fn),
            outs,
            outs_descs,
            out_spec,
            out_desc_spec,
        )
E       torch._dynamo.exc.BackendCompilerFailed: backend='<thunder.dynamo.compiler.ThunderCompiler object at 0x7d3ece010350>' raised:
E       AssertionError: ([<function aot_stage1_graph_capture.<locals>.orig_flat_fn2 at 0x7d3ece392c00>, <function create_functional_call.<locals>.functional_call at 0x7d3ece392b60>, GraphModule(
E         (wrap_body_0): GraphModule()
E       )], [(FunctionalTensor(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(128, 128)),
E              device='cuda:0')),)], [PlainAOTOutput(idx=0)], TreeSpec(immutable_list, None, [TreeSpec(tuple, None, [*])]), TreeSpec(immutable_list, None, [*]))
E       
E       Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py:555: BackendCompilerFailed
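
For reference, a rough sketch of the "wrap the input for Inductor" idea, assuming it means handing torch.compile a thin wrapper module around the fallback GraphModule rather than post-processing the callable returned by compile_fx (FallbackWrapper is a made-up name, not code from this PR):

    import torch

    class FallbackWrapper(torch.nn.Module):  # hypothetical helper
        def __init__(self, gm: torch.fx.GraphModule):
            super().__init__()
            self.gm = gm

        def forward(self, *args):
            return self.gm(*args)

    def compile_fallback(gm: torch.fx.GraphModule):
        # Let torch.compile drive Dynamo + Inductor end to end on the wrapped
        # module, instead of calling compile_fx and wrapping what it returns.
        return torch.compile(FallbackWrapper(gm), backend="inductor")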

Development

Successfully merging this pull request may close these issues.

ThunderFX's fallback is not using Inductor compilation
Activation checkpoint not working inside Inductor-compiled submodules