thunder/dynamo/compiler.py (3 changes: 1 addition & 2 deletions)
@@ -128,7 +128,6 @@ def __init__(self, **thunder_options):
"thunderfx_disable_split_autograd", _DEFAULT_THUNDERFX_DISABLE_SPLIT_AUTOGRAD
)
self.thunder_options = thunder_options
self._torch_compile = torch.compile

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
from thunder import jit
@@ -148,7 +147,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
split_module, subgraph_info = _splitter(
gm,
partial(jit, **thunder_options),
self._torch_compile,
torch._inductor.compile,
sample_args,
)
self.subgraph_infos.append(subgraph_info)
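
For context, a minimal sketch of what the entrypoint swap above changes (assuming a plain FX graph traced outside Dynamo; this is not the PR's exact call site): `torch.compile` wraps a function and compiles lazily on the first call, while `torch._inductor.compile` lowers a `GraphModule` eagerly and therefore needs example inputs up front, which is why the splitter now has to provide (fake) arguments for each inductor submodule. The same swap appears in thunder/dynamo/report.py below.

```python
import torch

def f(x):
    return x.sin() + 1

gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(4)]

# Eager lowering: compiled now, returns a plain callable (not an nn.Module).
inductor_fn = torch._inductor.compile(gm, example_inputs)
# Lazy compilation: nothing is compiled until the first call.
dynamo_fn = torch.compile(f)

x = torch.randn(4)
torch.testing.assert_close(inductor_fn(x), dynamo_fn(x))
```
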
thunder/dynamo/report.py (2 changes: 1 addition & 1 deletion)
@@ -1107,7 +1107,7 @@ def foo(x):
thunder_options[k] = v

thunder_jit = partial(jit, **thunder_options, nv_save_fake_inputs=True)
_, subgraph_info = _splitter(gm, thunder_jit, torch.compile, _unused_sample_args=None)
_, subgraph_info = _splitter(gm, thunder_jit, torch._inductor.compile, _unused_sample_args=None)

thunder_module_names = [f"{report.graph_name}_{name}" for name in get_thunder_module_names(subgraph_info)]
original_modules_to_thunder_modules = (
thunder/dynamo/splitter.py (46 changes: 44 additions & 2 deletions)
@@ -1,11 +1,15 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING
import copy
from functools import partial
import warnings

import torch
from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch.fx.passes.split_module import split_module

from thunder.core import baseutils
from thunder.dynamo.utils import (
SubgraphInfo,
CompiledFunction,
@@ -17,6 +21,7 @@
update_node_and_submodule,
recompile_graph,
checkpoint_converter,
make_fake_arguments,
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
)
@@ -96,11 +101,12 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
partition_cnt = 0
supported_partitions: set[int] = set()
split_reasons: list[SplitReason] = []
unsupported_collection_users: set[torch.fx.Node] = set()

nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

def callback(node) -> int:
nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
nonlocal prev_value, partition_cnt, split_reasons, supported_partitions, unsupported_collection_users

assert node.op not in (
"placeholder",
@@ -118,11 +124,19 @@ def callback(node) -> int:
info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
)
split_reasons.append(split_reason)
elif node in unsupported_collection_users:
is_thunder_supported = False
split_reason = None
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)

if not is_thunder_supported and baseutils.is_collection(node.meta.get("example_value", None)):
for user in node.users:
assert user.target is operator.getitem
unsupported_collection_users.add(user)

if prev_value == is_thunder_supported: # We are in the same region.
return partition_cnt

@@ -198,7 +212,35 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
)
elif node.name.startswith("submod"): # For inductor
graph_module = getattr(split_gm, node.name)
jit_fn = torch_inductor(graph_module)

class ModuleWrapper(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)

def fallback_torch_compile(reason: str) -> torch.nn.Module:
warnings.warn(f"{reason} Falling back to torch.compile.")
# torch.compile does not lower GraphModule properly. See https://github.com/Lightning-AI/lightning-thunder/issues/2539
# We work around this by wrapping it in a Module
return torch.compile(ModuleWrapper(graph_module))

fake_args = make_fake_arguments(graph_module)
if fake_args is None:
jit_fn = fallback_torch_compile("Example values for arguments are not available.")
else:
try:
# torch._inductor.compile returns a function, but update_node_and_submodule expects a Module
jit_fn = ModuleWrapper(torch_inductor(graph_module, fake_args))
except DynamicOutputShapeException as e:
# This exception is meant to be handled by Dynamo, which is responsible for the graph break
jit_fn = fallback_torch_compile(f"Dynamic output shape operator encountered: {e}.")

# This is for ease of debugging. We add the graph attribute so GraphModule.print_readable will print it
jit_fn.graph = graph_module.graph

# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
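
To illustrate the new `unsupported_collection_users` handling in the callback above: when a node's `example_value` is a collection (for example a tuple), all of its users are `operator.getitem` nodes, and the callback now keeps those users in the same unsupported partition as their producer so the collection never has to cross a partition boundary. A hedged sketch on a hand-traced graph; `torch.topk` is only an assumed stand-in for a collection-producing node that Thunder would reject.

```python
import operator
import torch

def f(x):
    out = torch.topk(x, 2)   # a single FX node whose result is a tuple
    return out[0], out[1]    # each indexing becomes an operator.getitem node

gm = torch.fx.symbolic_trace(f)
topk_node = next(n for n in gm.graph.nodes if n.op == "call_function" and n.target is torch.topk)

# Every consumer of the tuple is a getitem node; if topk were unsupported by Thunder,
# the callback would add these users to unsupported_collection_users so they are
# partitioned together with the topk node.
assert all(user.target is operator.getitem for user in topk_node.users)
```
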
thunder/dynamo/utils.py (19 changes: 19 additions & 0 deletions)
@@ -12,6 +12,8 @@
import torch
from torch.nn.modules.module import _addindent
from torch.utils.weak import TensorWeakRef
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor

if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
@@ -1050,3 +1052,20 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn):
err_msg = ", ".join([f"{x.name} raised exception: {x.compiled_fn}" for x in sorted_compiled_gm_to_measurement])
raise RuntimeError(f"No compiler was able to compile the graph module, {err_msg}")
return sorted_compiled_gm_to_measurement[0].compiled_fn


def make_fake_arguments(gm: torch.fx.GraphModule) -> list[FakeTensor] | None:
fake_mode = detect_fake_mode()
if fake_mode is None:
fake_mode = FakeTensorMode()
args = []
for node in gm.graph.nodes:
if node.op == "placeholder":
meta_val = node.meta.get("example_value")
if meta_val is None:
return None
if isinstance(meta_val, torch.Tensor):
# Tie to the currently enabled fake mode
meta_val = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, meta_val)
args.append(meta_val)
return args
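
A hedged usage sketch of the `make_fake_arguments` helper added above. In the real flow Dynamo populates `node.meta["example_value"]` on placeholders during tracing; here it is set by hand on a symbolically traced graph purely for illustration.

```python
import torch
from thunder.dynamo.utils import make_fake_arguments

def f(x):
    return x.cos()

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "placeholder":
        # Dynamo would normally attach this metadata during tracing.
        node.meta["example_value"] = torch.randn(2, 2)

fake_args = make_fake_arguments(gm)   # list of FakeTensors mirroring the placeholders
assert fake_args is not None and len(fake_args) == 1
# If any placeholder lacks example_value metadata, the helper returns None and the
# splitter falls back to torch.compile (see splitter.py above).
```
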
thunder/tests/test_dynamo.py (100 changes: 94 additions & 6 deletions)
@@ -155,6 +155,29 @@ def func(x):
assert any(target.startswith("thunder_") for target in targets) # Verify that the submodules have name `thunder_*`


@instantiate(
dtypes=NOTHING,
decorators=(
pytest.mark.skipif(
condition=IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
),
)
def test_fallback_to_inductor(executor, device, dtype):
x = torch.randn(3, 3, device=device, dtype=dtype)

def func(x):
return x.sinc().cos().sinc().sinc()

cfunc = thunderfx(func)
with patch("torch._inductor.compile", side_effect=torch._inductor.compile) as mock_inductor:
cfunc(x)

# Once for sinc() and once for sinc().sinc()
assert mock_inductor.call_count == 2


@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],
@@ -844,6 +867,35 @@ def find_target_module(model, target_module_name):
assert isinstance(n.target, Symbol) or callable(n.target)


@requiresCUDA
@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
def test_checkpoint_memory_use(op):
import torch.utils.checkpoint as checkpoint

def fn(x):
return op(op(op(op(x))))

def checkpoint_fn(x):
return checkpoint.checkpoint(fn, x, use_reentrant=False)

initial_mem = torch.cuda.memory_allocated()

x = torch.randn((128, 128), device="cuda", requires_grad=True)
jfn = thunderfx(checkpoint_fn)
y = jfn(x)

peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem

y_ref = fn(x)
torch.testing.assert_close(y, y_ref)

assert peak_mem_usage == x.nbytes * 2
if op == torch.sinc:
# Make sure the checkpointed region fell back to PyTorch
sinfo = jfn._backend.subgraph_infos[-1]
assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)


@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],
@@ -935,6 +987,8 @@ def forward(self, x):


def test_deepcopy_graph_module():
import thunder

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -946,9 +1000,8 @@ def forward(self, x):
gm = torch.fx.symbolic_trace(m)
n = gm.graph.find_nodes(op="output")
gm.graph.erase_node(n[0])
import thunder

_, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, thunder.jit, [])
_, subgraph_info = thunder.dynamo.splitter._splitter(gm, thunder.jit, torch._inductor.compile, [])
original_split_gm = subgraph_info.original_split_graph_module.split_graph_module
assert original_split_gm.graph.find_nodes(op="output")
for subm in original_split_gm.children():
@@ -962,7 +1015,13 @@ def forward(self, x):
@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],
decorators=(pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),),
decorators=(
pytest.mark.parametrize("use_pytest_benchmark", (True, False), ids=("benchmark", "repro")),
pytest.mark.skipif(
condition=IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
),
)
@given(file_indices=st.lists(st.integers(min_value=0, max_value=2), min_size=2, max_size=2, unique=True))
@settings(max_examples=1, deadline=None)
@@ -1231,6 +1290,10 @@ def foo(x, y):
run_script(file, cmd)


@pytest.mark.skipif(
condition=IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
)
def test_leak_on_unsupported_thunder_operator():
# This test is to check the fix for a previous leak
# which was caused by holding onto the
@@ -1621,7 +1684,7 @@ def foo(x):
pfoo(x)


def test_spliter_bwd():
def test_splitter_bwd():
def fn(x, idx, val):
x = x.clone()
x[idx] = val
@@ -1633,7 +1696,8 @@ def fn(x, idx, val):
val = torch.randn(nz, dtype=torch.bfloat16, requires_grad=True)

cfn = thunderfx(fn)
cfn(x, idx, val)
with pytest.warns(match="Dynamic output shape operator"):
cfn(x, idx, val)
reason = cfn._backend.subgraph_infos[0].split_reasons
assert len(reason) == 1
assert "Failed while running meta for node with name: setitem" in reason[0].info
@@ -1719,12 +1783,36 @@ def test_fn(x):
return z + 2

x = torch.tensor([1, 2, 3, 4, 5])
actual = thunderfx(test_fn)(x)
# Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
with patch("torch._dynamo.config.capture_scalar_outputs", True):
with pytest.warns(match="Example values for arguments are not available"):
actual = thunderfx(test_fn)(x)
expected = test_fn(x)
torch.testing.assert_close(actual, expected)


@pytest.mark.xfail(reason="Unsupported: Dynamo can't retrace autocast enter/exit", raises=AssertionError)
def test_thunderfx_no_example_value_and_autocast():
def fn(x):
with torch.autocast("cpu"):
y = x + 10
z = y.tolist()[0]
return z + 2

x = torch.tensor([1, 2, 3, 4, 5])
# Without this patch, tolist() would cause a graph break. See https://github.com/pytorch/pytorch/pull/163807
with patch("torch._dynamo.config.capture_scalar_outputs", True):
with pytest.warns(match="Example values for arguments are not available"):
actual = thunderfx(fn)(x)
expected = fn(x)
torch.testing.assert_close(actual, expected)


# This test addresses the bug reported in https://github.com/Lightning-AI/lightning-thunder/issues/2398
@pytest.mark.skipif(
condition=IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
)
def test_no_grad_region_split():
def fn(x):
# Thunder supports enclosing torch.set_grad_enabled(False/True)