🐛 Bug
When ThunderFX delegates activation checkpointing operations to Inductor-compiled submodules, the checkpointing mechanism apparently fails to take effect, so the intermediate activations are retained and memory usage is far higher than expected. This likely affects every model that contains Thunder-unsupported ops inside activation-checkpointed regions, causing e.g. #2501.
Note: when Thunder encounters an unsupported op in a checkpointed region, the entire region is handed to Inductor. The large memory usage we see today may therefore also be a symptom of Thunder's limited operator coverage.
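Before the repro, a minimal eager-mode baseline (a sketch: no compilers involved, assumes a CUDA device; the absolute numbers differ from the compiled runs below because eager mode materializes temporaries, but the checkpointed forward should still peak lower since autograd does not retain the sinc intermediates):

import torch
from torch.utils.checkpoint import checkpoint


def fn(x):
    return x.sinc().sinc().sinc()


x = torch.randn((1024, 1024, 1024), device="cuda", requires_grad=True)  # 4 GB

torch.cuda.reset_peak_memory_stats()
y = fn(x)  # autograd keeps the sinc inputs alive for backward
plain_peak = torch.cuda.max_memory_allocated()
del y

torch.cuda.reset_peak_memory_stats()
y = checkpoint(fn, x, use_reentrant=False)  # intermediates are recomputed during backward instead
ckpt_peak = torch.cuda.max_memory_allocated()
del y

print(f"plain: {plain_peak / 2**30:.1f} GB, checkpointed: {ckpt_peak / 2**30:.1f} GB")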
Code sample
import torch
from torch.utils.checkpoint import checkpoint
from thunder.dynamo import thunderfx


def fn(x):
    return x.sinc().sinc().sinc()


def checkpoint_fn(x):
    return checkpoint(fn, x, use_reentrant=False)


for compiler in [thunderfx, torch.compile]:
    torch.cuda.reset_peak_memory_stats()
    assert torch.cuda.memory_allocated() == 0
    x = torch.randn((1024, 1024, 1024), device="cuda", requires_grad=True)  # 4 GB
    y = compiler(checkpoint_fn)(x)
    del x, y
    peak_mem_usage = torch.cuda.max_memory_allocated()
    print(f"{compiler.__name__}: {peak_mem_usage / 1024 / 1024 / 1024} GB")
which prints:
thunderfx: 16.0 GB
compile: 8.0 GB
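For what it's worth, this is how I read the two peaks (an assumption about which 4 GB tensors stay alive, not verified with a memory trace):

# Each f32[1024, 1024, 1024] tensor is 4 GB.
# torch.compile (checkpointing takes effect): x + sinc_2                  -> 2 * 4 GB =  8 GB
# thunderfx (checkpointing not applied):      x + sinc + sinc_1 + sinc_2  -> 4 * 4 GB = 16 GB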
Input for thunder.dynamo.ThunderCompiler:
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[1024, 1024, 1024]"):
        l_x_ = L_x_

        # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
        wrap_body_0 = self.wrap_body_0
        tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
        getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[1024, 1024, 1024]"):
            # File: /opt/pytorch/lightning-thunder/tmp/repro.py:6 in fn, code: return x.sinc().sinc().sinc()
            sinc: "f32[1024, 1024, 1024]" = l_x_.sinc(); l_x_ = None
            sinc_1: "f32[1024, 1024, 1024]" = sinc.sinc(); sinc = None
            sinc_2: "f32[1024, 1024, 1024]" = sinc_1.sinc(); sinc_1 = None
            return (sinc_2,)
Output:
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[1024, 1024, 1024]"):
        # No stacktrace found for following nodes
        inductor_0 = self.inductor_0(l_x_); l_x_ = None
        thunder_1 = self.thunder_1(inductor_0); inductor_0 = None
        return (thunder_1,)

    class inductor_0(torch.nn.Module):
        def forward(self, l_x_: "f32[1024, 1024, 1024]"):
            # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
            wrap_body_0 = self.wrap_body_0
            tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
            return tag_activation_checkpoint

        class _orig_mod(torch.nn.Module):
            def forward(self, l_x_: "f32[1024, 1024, 1024]"):
                # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
                wrap_body_0 = self.wrap_body_0
                tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
                return tag_activation_checkpoint

            class wrap_body_0(torch.nn.Module):
                def forward(self, l_x_: "f32[1024, 1024, 1024]"):
                    # File: /opt/pytorch/lightning-thunder/tmp/repro.py:6 in fn, code: return x.sinc().sinc().sinc()
                    sinc: "f32[1024, 1024, 1024]" = l_x_.sinc(); l_x_ = None
                    sinc_1: "f32[1024, 1024, 1024]" = sinc.sinc(); sinc = None
                    sinc_2: "f32[1024, 1024, 1024]" = sinc_1.sinc(); sinc_1 = None
                    return (sinc_2,)

    class thunder_1(torch.nn.Module):
        def forward(self, tag_activation_checkpoint):
            # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
            getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
            return getitem

        class _model(torch.nn.Module):
            def forward(self, tag_activation_checkpoint):
                # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
                getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
                return getitem