From 1bb32071b54f7b4fc766a41b236d967aea6e6dd8 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 29 Sep 2025 04:51:45 -0700
Subject: [PATCH 1/8] Use _inductor.compile for entrypoint

---
 thunder/dynamo/compiler.py | 3 +--
 thunder/dynamo/report.py   | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 2ee582bba0..c339aae8e2 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -128,7 +128,6 @@ def __init__(self, **thunder_options):
             "thunderfx_disable_split_autograd", _DEFAULT_THUNDERFX_DISABLE_SPLIT_AUTOGRAD
         )
         self.thunder_options = thunder_options
-        self._torch_compile = torch.compile

     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         from thunder import jit
@@ -148,7 +147,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         split_module, subgraph_info = _splitter(
             gm,
             partial(jit, **thunder_options),
-            self._torch_compile,
+            torch._inductor.compile,
             sample_args,
         )
         self.subgraph_infos.append(subgraph_info)
diff --git a/thunder/dynamo/report.py b/thunder/dynamo/report.py
index 3e9c1b900b..a7f73ed7a7 100644
--- a/thunder/dynamo/report.py
+++ b/thunder/dynamo/report.py
@@ -1107,7 +1107,7 @@ def foo(x):
             thunder_options[k] = v

     thunder_jit = partial(jit, **thunder_options, nv_save_fake_inputs=True)
-    _, subgraph_info = _splitter(gm, thunder_jit, torch.compile, _unused_sample_args=None)
+    _, subgraph_info = _splitter(gm, thunder_jit, torch._inductor.compile, _unused_sample_args=None)
     thunder_module_names = [f"{report.graph_name}_{name}" for name in get_thunder_module_names(subgraph_info)]
     original_modules_to_thunder_modules = (

From 9cfb006f908b10f2a30741e8f3df0135d0ae1835 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Fri, 3 Oct 2025 13:56:09 -0700
Subject: [PATCH 2/8] Group submodule and getattr of its output together

---
 thunder/dynamo/splitter.py   | 12 +++++++++++-
 thunder/tests/test_dynamo.py |  3 ++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..b3796b6624 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import operator
 from typing import TYPE_CHECKING
 import copy
 from functools import partial
@@ -96,11 +97,12 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
     partition_cnt = 0
     supported_partitions: set[int] = set()
     split_reasons: list[SplitReason] = []
+    unsupported_collection_users: set[torch.fx.Node] = set()

     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

     def callback(node) -> int:
-        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
+        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions, unsupported_collection_users

         assert node.op not in (
             "placeholder",
@@ -118,11 +120,19 @@ def callback(node) -> int:
                 info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
             )
            split_reasons.append(split_reason)
+        elif node in unsupported_collection_users:
+            is_thunder_supported = False
+            split_reason = None
         else:
             is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
             if split_reason is not None:
                 split_reasons.append(split_reason)

+        if not is_thunder_supported and baseutils.is_collection(node.meta.get("example_value", None)):
+            for user in node.users:
+                assert user.target is operator.getitem
+                unsupported_collection_users.add(user)
+
         if prev_value == is_thunder_supported:
             # We are in the same region.
             return partition_cnt
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a90e2fe89..1a89af6d13 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1633,7 +1633,8 @@ def fn(x, idx, val):
     val = torch.randn(nz, dtype=torch.bfloat16, requires_grad=True)

     cfn = thunderfx(fn)
-    cfn(x, idx, val)
+    with pytest.warns(match="Dynamic output shape operator"):
+        cfn(x, idx, val)
     reason = cfn._backend.subgraph_infos[0].split_reasons
     assert len(reason) == 1
     assert "Failed while running meta for node with name: setitem" in reason[0].info

From fffc5a45dafd1ac97b748aa56952a4bf7dff763c Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Fri, 3 Oct 2025 14:46:28 -0700
Subject: [PATCH 3/8] Add Inductor fallback

---
 thunder/dynamo/splitter.py | 31 ++++++++++++++++++++++++++++++-
 thunder/dynamo/utils.py    | 19 +++++++++++++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index b3796b6624..f8bacac534 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -3,10 +3,13 @@
 from typing import TYPE_CHECKING
 import copy
 from functools import partial
+import warnings

 import torch
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
 from torch.fx.passes.split_module import split_module

+from thunder.core import baseutils
 from thunder.dynamo.utils import (
     SubgraphInfo,
     CompiledFunction,
@@ -18,6 +21,7 @@
     update_node_and_submodule,
     recompile_graph,
     checkpoint_converter,
+    make_fake_arguments,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
 )
@@ -208,7 +212,32 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
-            jit_fn = torch_inductor(graph_module)
+
+            class ModuleWrapper(torch.nn.Module):
+                def __init__(self, fn):
+                    super().__init__()
+                    self.fn = fn
+
+                def forward(self, *args, **kwargs):
+                    return self.fn(*args, **kwargs)
+
+            def fallback_torch_compile(reason: str) -> torch.nn.Module:
+                warnings.warn(f"{reason} Falling back to torch.compile.")
+                # torch.compile does not lower GraphModule properly. See https://github.com/Lightning-AI/lightning-thunder/issues/2539
+                # We work around this by wrapping it in a Module
+                return torch.compile(ModuleWrapper(graph_module))
+
+            fake_args = make_fake_arguments(graph_module)
+            if fake_args is None:
+                jit_fn = fallback_torch_compile("Example values for arguments are not available.")
+            else:
+                try:
+                    # torch._inductor.compile returns a function, but update_node_and_submodule expects a Module
+                    jit_fn = ModuleWrapper(torch_inductor(graph_module, fake_args))
+                except DynamicOutputShapeException as e:
+                    # This exception is meant to be handled by Dynamo, which is responsible for graph break
+                    jit_fn = fallback_torch_compile(f"Dynamic output shape operator encountered: {e}.")
+
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..178681b7f0 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+from torch._guards import detect_fake_mode
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -1050,3 +1052,20 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
         err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
         raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
     return sorted_compiled_gm_to_measurement[0].compiled_fn
+
+
+def make_fake_arguments(gm: torch.fx.GraphModule) -> list[FakeTensor] | None:
+    fake_mode = detect_fake_mode()
+    if fake_mode is None:
+        fake_mode = FakeTensorMode()
+    args = []
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            meta_val = node.meta.get("example_value")
+            if meta_val is None:
+                return None
+            if isinstance(meta_val, torch.Tensor):
+                # Tie to the currently enabled fake mode
+                meta_val = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, meta_val)
+            args.append(meta_val)
+    return args

From 2f8c14371d495be8765693d919fc8105c2606306 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Thu, 25 Sep 2025 07:55:08 -0700
Subject: [PATCH 4/8] Add test

---
 thunder/tests/test_dynamo.py | 68 +++++++++++++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 4 deletions(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a89af6d13..aab6eb5f97 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -155,6 +155,21 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`


+@instantiate(dtypes=NOTHING)
+def test_fallback_to_inductor(executor, device, dtype):
+    x = torch.randn(3, 3, device=device, dtype=dtype)
+
+    def func(x):
+        return x.sinc().cos().sinc().sinc()
+
+    cfunc = thunderfx(func)
+    with patch("torch._inductor.compile", side_effect=torch._inductor.compile) as mock_inductor:
+        cfunc(x)
+
+    # Once for sinc() and once for sinc().sinc()
+    assert mock_inductor.call_count == 2
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -844,6 +859,35 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)

+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    assert peak_mem_usage == x.nbytes * 2
+    if op == torch.sinc:
+        # Make sure the checkpointed region fell back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -935,6 +979,8 @@ def forward(self, x):


 def test_deepcopy_graph_module():
+    import thunder
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -946,9 +992,8 @@ def forward(self, x):
     gm = torch.fx.symbolic_trace(m)
     n = gm.graph.find_nodes(op="output")
     gm.graph.erase_node(n[0])
-    import thunder

-    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, [])
+    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, torch._inductor.compile, [])
     original_split_gm = subgraph_info.original_split_graph_module.split_graph_module
     assert original_split_gm.graph.find_nodes(op="output")
     for subm in original_split_gm.children():
@@ -1621,7 +1666,7 @@ def foo(x):
     pfoo(x)


-def test_spliter_bwd():
+def test_splitter_bwd():
     def fn(x, idx, val):
         x = x.clone()
         x[idx] = val
@@ -1720,11 +1765,26 @@ def test_fn(x):
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
-    actual = thunderfx(test_fn)(x)
+    with pytest.warns(match="Example values for arguments are not available"):
+        actual = thunderfx(test_fn)(x)
     expected = test_fn(x)
     torch.testing.assert_close(actual, expected)


+@pytest.mark.xfail(reason="Unsupported: Dynamo can't retrace autocast enter/exit", raises=AssertionError)
+def test_thunderfx_no_example_value_and_autocast():
+    def fn(x):
+        with torch.autocast("cpu"):
+            y = x + 10
+            z = y.tolist()[0]
+        return z + 2
+
+    x = torch.tensor([1, 2, 3, 4, 5])
+    actual = thunderfx(fn)(x)
+    expected = fn(x)
+    torch.testing.assert_close(actual, expected)
+
+
 # This test addresses the bug reported in https://github.com/Lightning-AI/lightning-thunder/issues/2398
 def test_no_grad_region_split():
     def fn(x):

From 6b00ae7ad2011cf52db363c12a767cce5208f031 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Fri, 3 Oct 2025 15:12:26 -0700
Subject: [PATCH 5/8] Add .graph attribute for visibility in print_readable

---
 thunder/dynamo/splitter.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index f8bacac534..969cc79266 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -238,6 +238,9 @@ def fallback_torch_compile(reason: str) -> torch.nn.Module:
                     # This exception is meant to be handled by Dynamo, which is responsible for graph break
                     jit_fn = fallback_torch_compile(f"Dynamic output shape operator encountered: {e}.")

+            # This is for ease of debugging. We add a graph attribute so GraphModule.print_readable will print it
+            jit_fn.graph = graph_module.graph
+
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(

From 1a3214aa030d9ea72fb8bc204ccfac11820db2c4 Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Sat, 4 Oct 2025 11:27:52 -0700
Subject: [PATCH 6/8] Address torch 2.10, which breaks the graph when example_value is unavailable

---
 thunder/tests/test_dynamo.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index aab6eb5f97..0dd9ee6c1d 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1765,13 +1765,20 @@ def test_fn(x):
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
-    with pytest.warns(match="Example values for arguments are not available"):
+    if LooseVersion(torch.__version__) < "2.10":
+        with pytest.warns(match="Example values for arguments are not available"):
+            actual = thunderfx(test_fn)(x)
+    else:
         actual = thunderfx(test_fn)(x)
     expected = test_fn(x)
     torch.testing.assert_close(actual, expected)


-@pytest.mark.xfail(reason="Unsupported: Dynamo can't retrace autocast enter/exit", raises=AssertionError)
+@pytest.mark.xfail(
+    condition=LooseVersion(torch.__version__) < "2.10",
+    reason="Unsupported: Dynamo can't retrace autocast enter/exit",
+    raises=AssertionError,
+)
 def test_thunderfx_no_example_value_and_autocast():
     def fn(x):
         with torch.autocast("cpu"):

From 11bb431b27b9f3011fc81b5df58914a42a7f244c Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Sat, 4 Oct 2025 11:40:39 -0700
Subject: [PATCH 7/8] Patch dynamo config to reproduce example_value being unavailable

---
 thunder/tests/test_dynamo.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 0dd9ee6c1d..01a615ff85 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1765,20 +1765,15 @@ def test_fn(x):
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
-    if LooseVersion(torch.__version__) < "2.10":
+    # Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
+    with patch("torch._dynamo.config.capture_scalar_outputs", True):
         with pytest.warns(match="Example values for arguments are not available"):
             actual = thunderfx(test_fn)(x)
-    else:
-        actual = thunderfx(test_fn)(x)
     expected = test_fn(x)
     torch.testing.assert_close(actual, expected)


-@pytest.mark.xfail(
-    condition=LooseVersion(torch.__version__) < "2.10",
-    reason="Unsupported: Dynamo can't retrace autocast enter/exit",
-    raises=AssertionError,
-)
+@pytest.mark.xfail(reason="Unsupported: Dynamo can't retrace autocast enter/exit", raises=AssertionError)
 def test_thunderfx_no_example_value_and_autocast():
     def fn(x):
         with torch.autocast("cpu"):
@@ -1787,7 +1782,10 @@ def fn(x):
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
-    actual = thunderfx(fn)(x)
+    # Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
+    with patch("torch._dynamo.config.capture_scalar_outputs", True):
+        with pytest.warns(match="Example values for arguments are not available"):
+            actual = thunderfx(fn)(x)
     expected = fn(x)
     torch.testing.assert_close(actual, expected)

From cef360e254a1ae169b14ffbf745b46a99a6cdb8c Mon Sep 17 00:00:00 2001
From: Masato Shinokawa
Date: Mon, 6 Oct 2025 02:45:14 -0700
Subject: [PATCH 8/8] Skip some thunderfx tests on Windows

---
 thunder/tests/test_dynamo.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 01a615ff85..329027eac1 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -155,7 +155,15 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`


-@instantiate(dtypes=NOTHING)
+@instantiate(
+    dtypes=NOTHING,
+    decorators=(
+        pytest.mark.skipif(
+            condition=IS_WINDOWS,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
+)
 def test_fallback_to_inductor(executor, device, dtype):
     x = torch.randn(3, 3, device=device, dtype=dtype)

@@ -1007,7 +1015,13 @@ def forward(self, x):
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
-    decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),),
+    decorators=(
+        pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),
+        pytest.mark.skipif(
+            condition=IS_WINDOWS,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
 )
 @given(file_indices=st.lists(st.integers(min_value=0, max_value=2), min_size=2, max_size=2, unique=True))
 @settings(max_examples=1, deadline=None)
@@ -1276,6 +1290,10 @@ def foo(x, y):
     run_script(file, cmd)


+@pytest.mark.skipif(
+    condition=IS_WINDOWS,
+    reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+)
 def test_leak_on_unsupported_thunder_operator():
     # This test is to check the fix for a previous leak
     # which was caused by holding onto the
@@ -1791,6 +1809,10 @@ def fn(x):


 # This test addresses the bug reported in https://github.com/Lightning-AI/lightning-thunder/issues/2398
+@pytest.mark.skipif(
+    condition=IS_WINDOWS,
+    reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+)
 def test_no_grad_region_split():
     def fn(x):
         # Thunder supports enclosing torch.set_grad_enabled(False/True)