thunder/dynamo/splitter.py (13 changes: 12 additions & 1 deletion)
@@ -198,7 +198,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             )
         elif node.name.startswith("submod"):  # For inductor
             graph_module = getattr(split_gm, node.name)
-            jit_fn = torch_inductor(graph_module)
+
+            class Wrapped(torch.nn.Module):
+                def __init__(self, gm):
+                    super().__init__()
+                    self.gm = gm
+
+                def forward(self, *a):
+                    return self.gm(*a)
+
+            # Make sure Inductor does not skip graph_module's compilation by wrapping it
+            # See https://github.com/Lightning-AI/lightning-thunder/issues/2527#issuecomment-3345877210
+            jit_fn = torch_inductor(Wrapped(graph_module))
             # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
             update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
             submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction(
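For context, here is a minimal standalone sketch of the wrapping pattern introduced above. It is not part of this PR, and `torch.compile(..., backend="inductor")` stands in for the splitter's `torch_inductor` helper:

import torch

def fn(x):
    return x.sin().cos()

# A traced GraphModule, similar to the submodules the splitter hands to Inductor.
gm = torch.fx.symbolic_trace(fn)

class Wrapped(torch.nn.Module):
    """Plain nn.Module that only forwards to the wrapped GraphModule."""

    def __init__(self, gm):
        super().__init__()
        self.gm = gm

    def forward(self, *a):
        return self.gm(*a)

# Compiling the wrapper rather than the raw GraphModule mirrors the change above:
# it keeps Inductor from skipping compilation of the already-traced module.
compiled = torch.compile(Wrapped(gm), backend="inductor")
print(compiled(torch.randn(4)))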
thunder/tests/test_dynamo.py (47 changes: 47 additions & 0 deletions)
@@ -155,6 +155,24 @@ def func(x):
     assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
 
 
+@instantiate(dtypes=NOTHING)
+def test_inductor_fallback(executor, device, dtype):
+    x = torch.randn(3, 3, device=device, dtype=dtype)
+
+    def func(x):
+        return x.sinc().cos().sinc().sinc()
+
+    def trivial_compile(model, *args, **kwargs):
+        return model
+
+    cfunc = thunderfx(func)
+    with patch("torch._inductor.compile_fx.compile_fx", side_effect=trivial_compile) as mock_call:
+        cfunc(x)
+
+    # Once for sinc() and once for sinc().sinc()
+    assert mock_call.call_count == 2
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
@@ -844,6 +862,35 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)
 
 
+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    assert peak_mem_usage == x.nbytes * 2
+    if op == torch.sinc:
+        # Make sure the checkpointed region fell back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
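As a usage note rather than part of the diff, the split produced by `thunderfx` can be inspected through the same attributes the new tests rely on; the import path of `thunderfx` below is an assumption, and the printed names are only what the tests above lead one to expect:

import torch
from thunder.dynamo import thunderfx

def func(x):
    # sinc() is unsupported by Thunder, so those calls should fall back to Inductor,
    # while cos() should stay in a thunder_* submodule.
    return x.sinc().cos().sinc().sinc()

cfunc = thunderfx(func)
cfunc(torch.randn(3, 3))

sinfo = cfunc._backend.subgraph_infos[-1]
print([n.name for n in sinfo.split_graph_module.graph.nodes])
# Expected to include a mix of thunder_* and inductor_* submodule names,
# matching the renaming done in splitter.py above.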