
Activation checkpoint not working inside Inductor-compiled submodules #2527

@shino16

Description


🐛 Bug

When ThunderFX delegates an activation-checkpointed region to an Inductor-compiled submodule, the checkpointing mechanism apparently does not take effect, so intermediate activations are retained and memory usage becomes excessive. This likely affects any model whose activation-checkpointed regions contain ops Thunder does not support, and may be the cause of e.g. #2501.

Note: when Thunder encounters an unsupported op in a checkpointed region, the entire region is handed to Inductor. The current huge memory usage can therefore be a useful signal of gaps in Thunder's operator coverage.

Code sample

import torch
from torch.utils.checkpoint import checkpoint
from thunder.dynamo import thunderfx

def fn(x):
    return x.sinc().sinc().sinc()

def checkpoint_fn(x):
    return checkpoint(fn, x, use_reentrant=False)

for compiler in [thunderfx, torch.compile]:
    torch.cuda.reset_peak_memory_stats()
    assert torch.cuda.memory_allocated() == 0

    x = torch.randn((1024, 1024, 1024), device="cuda", requires_grad=True)  # 4 GB
    y = compiler(checkpoint_fn)(x)
    del x, y

    peak_mem_usage = torch.cuda.max_memory_allocated()
    print(f"{compiler.__name__}: {peak_mem_usage / 1024 / 1024 / 1024} GB")
This prints:

thunderfx: 16.0 GB
compile: 8.0 GB
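For reference, the Dynamo-captured graph shown below as the input to thunder.dynamo.ThunderCompiler can be dumped with a small printing backend. This is only a sketch (dump_backend is an illustrative name); it relies on GraphModule.print_readable(), which also prints nested submodules such as wrap_body_0:

import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return x.sinc().sinc().sinc()

def checkpoint_fn(x):
    return checkpoint(fn, x, use_reentrant=False)

def dump_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the FX graph captured by Dynamo, including the
    # tag_activation_checkpoint higher-order op and its wrap_body_0 submodule.
    gm.print_readable()
    return gm.forward  # run the captured graph as-is

# A small CPU input is enough for dumping; the shape annotations in the
# printout simply reflect whatever example input is used here.
x = torch.randn((16, 16), requires_grad=True)
torch.compile(checkpoint_fn, backend=dump_backend)(x)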

Input for thunder.dynamo.ThunderCompiler:

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[1024, 1024, 1024]"):
        l_x_ = L_x_
        
         # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
        wrap_body_0 = self.wrap_body_0
        tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False);  wrap_body_0 = l_x_ = None
        getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0];  tag_activation_checkpoint = None
        return (getitem,)
        
    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[1024, 1024, 1024]"):
             # File: /opt/pytorch/lightning-thunder/tmp/repro.py:6 in fn, code: return x.sinc().sinc().sinc()
            sinc: "f32[1024, 1024, 1024]" = l_x_.sinc();  l_x_ = None
            sinc_1: "f32[1024, 1024, 1024]" = sinc.sinc();  sinc = None
            sinc_2: "f32[1024, 1024, 1024]" = sinc_1.sinc();  sinc_1 = None
            return (sinc_2,)

Output of thunder.dynamo.ThunderCompiler (the split graph):

class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[1024, 1024, 1024]"):
        # No stacktrace found for following nodes
        inductor_0 = self.inductor_0(l_x_);  l_x_ = None
        thunder_1 = self.thunder_1(inductor_0);  inductor_0 = None
        return (thunder_1,)
        
    class inductor_0(torch.nn.Module):
        def forward(self, l_x_: "f32[1024, 1024, 1024]"):
             # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
            wrap_body_0 = self.wrap_body_0
            tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False);  wrap_body_0 = l_x_ = None
            return tag_activation_checkpoint
            
        class _orig_mod(torch.nn.Module):
            def forward(self, l_x_: "f32[1024, 1024, 1024]"):
                 # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
                wrap_body_0 = self.wrap_body_0
                tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False);  wrap_body_0 = l_x_ = None
                return tag_activation_checkpoint
                
            class wrap_body_0(torch.nn.Module):
                def forward(self, l_x_: "f32[1024, 1024, 1024]"):
                     # File: /opt/pytorch/lightning-thunder/tmp/repro.py:6 in fn, code: return x.sinc().sinc().sinc()
                    sinc: "f32[1024, 1024, 1024]" = l_x_.sinc();  l_x_ = None
                    sinc_1: "f32[1024, 1024, 1024]" = sinc.sinc();  sinc = None
                    sinc_2: "f32[1024, 1024, 1024]" = sinc_1.sinc();  sinc_1 = None
                    return (sinc_2,)
                    
    class thunder_1(torch.nn.Module):
        def forward(self, tag_activation_checkpoint):
             # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
            getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0];  tag_activation_checkpoint = None
            return getitem
            
        class _model(torch.nn.Module):
            def forward(self, tag_activation_checkpoint):
                 # File: /opt/pytorch/lightning-thunder/tmp/repro.py:9 in checkpoint_fn, code: return checkpoint(fn, x, use_reentrant=False)
                getitem: "f32[1024, 1024, 1024]" = tag_activation_checkpoint[0];  tag_activation_checkpoint = None
                return getitem
