diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..51f1bcf28a 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -198,7 +198,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
-            jit_fn = torch_inductor(graph_module)
+
+            class Wrapped(torch.nn.Module):
+                def __init__(self, gm):
+                    super().__init__()
+                    self.gm = gm
+
+                def forward(self, *a):
+                    return self.gm(*a)
+
+            # Make sure Inductor does not skip graph_module's compilation by wrapping it
+            # See https://github.com/Lightning-AI/lightning-thunder/issues/2527#issuecomment-3345877210
+            jit_fn = torch_inductor(Wrapped(graph_module))
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a90e2fe89..5cc191b54a 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -155,6 +155,24 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
 
 
+@instantiate(dtypes=NOTHING)
+def test_inductor_fallback(executor, device, dtype):
+    x = torch.randn(3, 3, device=device, dtype=dtype)
+
+    def func(x):
+        return x.sinc().cos().sinc().sinc()
+
+    def trivial_compile(model, *args, **kwargs):
+        return model
+
+    cfunc = thunderfx(func)
+    with patch("torch._inductor.compile_fx.compile_fx", side_effect=trivial_compile) as mock_call:
+        cfunc(x)
+
+    # Once for sinc() and once for sinc().sinc()
+    assert mock_call.call_count == 2
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -844,6 +862,35 @@ def find_target_module(model, target_module_name):
             assert isinstance(n.target, Symbol) or callable(n.target)
 
 
+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    assert peak_mem_usage == x.nbytes * 2
+    if op == torch.sinc:
+        # Make sure the checkpointed region fell back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
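
A minimal usage sketch of the fallback path this patch touches, assuming `thunderfx` and the `_backend.subgraph_infos` / `split_graph_module` attributes behave as exercised by the tests in this diff; `torch.sinc` is used only as an example of an operator Thunder does not claim, so the splitter routes it to Inductor via the new `Wrapped` module. The import path is an assumption, not taken from the patch.

import torch
from thunder.dynamo import thunderfx  # import path assumed; the tests above call thunderfx directly

def f(x):
    # cos is claimed by Thunder; sinc is not, so the splitter sends it to Inductor
    return x.sinc().cos()

cf = thunderfx(f)
cf(torch.randn(3, 3))

# After splitting, the graph should contain both "thunder_*" and "inductor_*" submodules,
# the latter compiled through the Wrapped module introduced above.
info = cf._backend.subgraph_infos[-1]
print(sorted(n.name for n in info.split_graph_module.graph.nodes))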