Skip to content

Commit ac5508a

Browse files
committed
Add test
1 parent 5756fd8 commit ac5508a

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

thunder/tests/test_dynamo.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,37 @@ def find_target_module(model, target_module_name):
844844
assert isinstance(n.target, Symbol) or callable(n.target)
845845

846846

847+
@requiresCUDA
848+
@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
849+
def test_checkpoint_memory_use(op):
850+
import torch.utils.checkpoint as checkpoint
851+
852+
def fn(x):
853+
return op(op(op(op(x))))
854+
855+
def checkpoint_fn(x):
856+
return checkpoint.checkpoint(fn, x, use_reentrant=False)
857+
858+
initial_mem = torch.cuda.memory_allocated()
859+
860+
x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
861+
jfn = thunderfx(checkpoint_fn)
862+
y = jfn(x)
863+
864+
peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
865+
866+
y_ref = fn(x)
867+
torch.testing.assert_close(y, y_ref)
868+
869+
if op == torch.sin:
870+
assert peak_mem_usage == x.nbytes * 2
871+
else:
872+
assert peak_mem_usage == x.nbytes * 3
873+
# Make sure the checkpointed region falled back to PyTorch
874+
sinfo = jfn._backend.subgraph_infos[-1]
875+
assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
876+
877+
847878
@instantiate(
848879
dtypes=NOTHING,
849880
executors=[DynamoThunderExecutor],

0 commit comments

Comments
 (0)