diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 2ee582bba0..c339aae8e2 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -128,7 +128,6 @@ def __init__(self, **thunder_options):
             "thunderfx_disable_split_autograd", _DEFAULT_THUNDERFX_DISABLE_SPLIT_AUTOGRAD
         )
         self.thunder_options = thunder_options
-        self._torch_compile = torch.compile

     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         from thunder import jit
@@ -148,7 +147,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         split_module, subgraph_info = _splitter(
             gm,
             partial(jit, **thunder_options),
-            self._torch_compile,
+            torch._inductor.compile,
             sample_args,
         )
         self.subgraph_infos.append(subgraph_info)
diff --git a/thunder/dynamo/report.py b/thunder/dynamo/report.py
index 3e9c1b900b..a7f73ed7a7 100644
--- a/thunder/dynamo/report.py
+++ b/thunder/dynamo/report.py
@@ -1107,7 +1107,7 @@ def foo(x):
             thunder_options[k] = v

     thunder_jit = partial(jit, **thunder_options, nv_save_fake_inputs=True)
-    _, subgraph_info = _splitter(gm, thunder_jit, torch.compile, _unused_sample_args=None)
+    _, subgraph_info = _splitter(gm, thunder_jit, torch._inductor.compile, _unused_sample_args=None)

     thunder_module_names = [f"{report.graph_name}_{name}" for name in get_thunder_module_names(subgraph_info)]
     original_modules_to_thunder_modules = (
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..969cc79266 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -1,11 +1,15 @@
 from __future__ import annotations
+import operator
 from typing import TYPE_CHECKING
 import copy
 from functools import partial
+import warnings

 import torch
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
 from torch.fx.passes.split_module import split_module

+from thunder.core import baseutils
 from thunder.dynamo.utils import (
     SubgraphInfo,
     CompiledFunction,
@@ -17,6 +21,7 @@
     update_node_and_submodule,
     recompile_graph,
     checkpoint_converter,
+    make_fake_arguments,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
 )
@@ -96,11 +101,12 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
     partition_cnt = 0
     supported_partitions: set[int] = set()
     split_reasons: list[SplitReason] = []
+    unsupported_collection_users: set[torch.fx.Node] = set()

     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

     def callback(node) -> int:
-        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
+        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions, unsupported_collection_users

         assert node.op not in (
             "placeholder",
@@ -118,11 +124,19 @@ def callback(node) -> int:
                 info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
             )
             split_reasons.append(split_reason)
+        elif node in unsupported_collection_users:
+            is_thunder_supported = False
+            split_reason = None
         else:
             is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
             if split_reason is not None:
                 split_reasons.append(split_reason)

+        if not is_thunder_supported and baseutils.is_collection(node.meta.get("example_value", None)):
+            for user in node.users:
+                assert user.target is operator.getitem
+                unsupported_collection_users.add(user)
+
         if prev_value == is_thunder_supported:  # We are in the same region.
            return partition_cnt

@@ -198,7 +212,35 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
-            jit_fn = torch_inductor(graph_module)
+
+            class ModuleWrapper(torch.nn.Module):
+                def __init__(self, fn):
+                    super().__init__()
+                    self.fn = fn
+
+                def forward(self, *args, **kwargs):
+                    return self.fn(*args, **kwargs)
+
+            def fallback_torch_compile(reason: str) -> torch.nn.Module:
+                warnings.warn(f"{reason} Falling back to torch.compile.")
+                # torch.compile does not lower a GraphModule properly. See https://github.com/Lightning-AI/lightning-thunder/issues/2539
+                # We work around this by wrapping it in a Module.
+                return torch.compile(ModuleWrapper(graph_module))
+
+            fake_args = make_fake_arguments(graph_module)
+            if fake_args is None:
+                jit_fn = fallback_torch_compile("Example values for arguments are not available.")
+            else:
+                try:
+                    # torch._inductor.compile returns a function, but update_node_and_submodule expects a Module.
+                    jit_fn = ModuleWrapper(torch_inductor(graph_module, fake_args))
+                except DynamicOutputShapeException as e:
+                    # This exception is meant to be handled by Dynamo, which is responsible for the graph break.
+                    jit_fn = fallback_torch_compile(f"Dynamic output shape operator encountered: {e}.")
+
+            # This is for ease of debugging: we add the graph attribute so that GraphModule.print_readable will print it.
+            jit_fn.graph = graph_module.graph
+
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..178681b7f0 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+from torch._guards import detect_fake_mode
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -1050,3 +1052,20 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
         err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
         raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
     return sorted_compiled_gm_to_measurement[0].compiled_fn
+
+
+def make_fake_arguments(gm: torch.fx.GraphModule) -> list[FakeTensor] | None:
+    fake_mode = detect_fake_mode()
+    if fake_mode is None:
+        fake_mode = FakeTensorMode()
+    args = []
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            meta_val = node.meta.get("example_value")
+            if meta_val is None:
+                return None
+            if isinstance(meta_val, torch.Tensor):
+                # Tie to the currently enabled fake mode
+                meta_val = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, meta_val)
+            args.append(meta_val)
+    return args
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a90e2fe89..329027eac1 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -155,6 +155,29 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`


+@instantiate(
+    dtypes=NOTHING,
+    decorators=(
+        pytest.mark.skipif(
+            condition=IS_WINDOWS,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
+)
+def test_fallback_to_inductor(executor, device, dtype):
+    x = torch.randn(3, 3, device=device, dtype=dtype)
+
+    def func(x):
+        return x.sinc().cos().sinc().sinc()
+
+    cfunc = thunderfx(func)
+    with patch("torch._inductor.compile", side_effect=torch._inductor.compile) as mock_inductor:
+        cfunc(x)
+
+    # Once for sinc() and once for sinc().sinc()
+    assert mock_inductor.call_count == 2
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -844,6 +867,35 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)


+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    assert peak_mem_usage == x.nbytes * 2
+    if op == torch.sinc:
+        # Make sure the checkpointed region fell back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -935,6 +987,8 @@ def forward(self, x):


 def test_deepcopy_graph_module():
+    import thunder
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -946,9 +1000,8 @@ def forward(self, x):
     gm = torch.fx.symbolic_trace(m)
     n = gm.graph.find_nodes(op="output")
     gm.graph.erase_node(n[0])
-    import thunder

-    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, [])
+    _, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, torch._inductor.compile, [])
     original_split_gm = subgraph_info.original_split_graph_module.split_graph_module
     assert original_split_gm.graph.find_nodes(op="output")
     for subm in original_split_gm.children():
@@ -962,7 +1015,13 @@ def forward(self, x):
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
-    decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),),
+    decorators=(
+        pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),
+        pytest.mark.skipif(
+            condition=IS_WINDOWS,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
 )
 @given(file_indices=st.lists(st.integers(min_value=0, max_value=2), min_size=2, max_size=2, unique=True))
 @settings(max_examples=1, deadline=None)
@@ -1231,6 +1290,10 @@ def foo(x, y):
     run_script(file, cmd)


+@pytest.mark.skipif(
+    condition=IS_WINDOWS,
+    reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+)
 def test_leak_on_unsupported_thunder_operator():
     # This test is to check the fix for a previous leak
     # which was caused by holding onto the
@@ -1621,7 +1684,7 @@ def foo(x):
     pfoo(x)


-def test_spliter_bwd():
+def test_splitter_bwd():
     def fn(x, idx, val):
         x = x.clone()
         x[idx] = val
@@ -1633,7 +1696,8 @@ def fn(x, idx, val):
     val = torch.randn(nz, dtype=torch.bfloat16, requires_grad=True)

     cfn = thunderfx(fn)
-    cfn(x, idx, val)
+    with pytest.warns(match="Dynamic output shape operator"):
+        cfn(x, idx, val)
     reason = cfn._backend.subgraph_infos[0].split_reasons
     assert len(reason) == 1
     assert "Failed while running meta for node with name: setitem" in reason[0].info
@@ -1719,12 +1783,36 @@ def test_fn(x):
         return z + 2

     x = torch.tensor([1, 2, 3, 4, 5])
-    actual = thunderfx(test_fn)(x)
+    # Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
+    with patch("torch._dynamo.config.capture_scalar_outputs", True):
+        with pytest.warns(match="Example values for arguments are not available"):
+            actual = thunderfx(test_fn)(x)
     expected = test_fn(x)
     torch.testing.assert_close(actual, expected)


+@pytest.mark.xfail(reason="Unsupported: Dynamo can't retrace autocast enter/exit", raises=AssertionError)
+def test_thunderfx_no_example_value_and_autocast():
+    def fn(x):
+        with torch.autocast("cpu"):
+            y = x + 10
+            z = y.tolist()[0]
+            return z + 2
+
+    x = torch.tensor([1, 2, 3, 4, 5])
+    # Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
+    with patch("torch._dynamo.config.capture_scalar_outputs", True):
+        with pytest.warns(match="Example values for arguments are not available"):
+            actual = thunderfx(fn)(x)
+    expected = fn(x)
+    torch.testing.assert_close(actual, expected)
+
+
 # This test addresses the bug reported in https://github.com/Lightning-AI/lightning-thunder/issues/2398
+@pytest.mark.skipif(
+    condition=IS_WINDOWS,
+    reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+)
 def test_no_grad_region_split():
     def fn(x):
         # Thunder supports enclosing torch.set_grad_enabled(False/True)
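
The snippet below is a minimal usage sketch, not part of the patch, showing how the new Inductor fallback surfaces to users. It mirrors test_fallback_to_inductor and test_checkpoint_memory_use above; the thunderfx entry point and the _backend.subgraph_infos attribute come from the tests in this diff, while the import path thunder.dynamo is an assumption based on the test suite.

# Hedged sketch (assumptions noted above). torch.sinc is used here only because the
# tests above treat it as an op Thunder does not support, so the splitter hands those
# regions to torch._inductor.compile instead of running them through Thunder.
import torch
from thunder.dynamo import thunderfx  # assumed import path, as used by the test suite


def func(x):
    return x.sinc().cos().sinc().sinc()


cfunc = thunderfx(func)
cfunc(torch.randn(3, 3))

# The unsupported regions now appear as "inductor_*" submodules in the split graph,
# matching the assertion in test_checkpoint_memory_use.
sinfo = cfunc._backend.subgraph_infos[-1]
print([n.name for n in sinfo.split_graph_module.graph.nodes if n.name.startswith("inductor")])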