Skip to content

[iree_boo] Memory corruption with CUDA graph capture #1296

@mjkvaak-amd

Description

@mjkvaak-amd

Summary

iree_boo backend crashes with malloc corruption during CUDA graph capture. The inductor backend works correctly with the same code. Crash occurs during the torch.cuda.graph() capture phase.

Error

malloc(): unsorted double linked list corrupted
Aborted (core dumped)

Environment

  • PyTorch: 2.9.1+rocm7.1.1.lw.git351ff442
  • IREE Turbine: 3.10.0rc20260204
  • Hardware: AMD MI355X (ROCm 7.1.1)
  • Python 3.12.3

Reproduction

#!/usr/bin/env python3
"""
Minimal reproducer for iree_boo + CUDA graphs memory corruption `Error: malloc(): unsorted double linked list corrupted`
Usage: python iree_cuda_graphs_bug.py
"""
import torch
import torch.nn as nn
import torchvision.models as models
import copy

def train_with_cuda_graphs(backend, num_steps=50):
    """Run one ResNet-50 training setup under CUDA graph capture/replay.

    Follows the standard PyTorch CUDA-graphs training recipe: warm up on a
    side stream, capture forward+loss+backward into a ``torch.cuda.CUDAGraph``,
    then replay the graph while stepping the optimizer eagerly outside it.

    Args:
        backend: ``torch.compile`` backend name (e.g. "inductor", "iree_boo").
        num_steps: Number of graph replay iterations to run after capture.

    Returns:
        True when all ``num_steps`` replays complete. (Per this issue, the
        ``iree_boo`` backend aborts the whole process with
        ``malloc(): unsorted double linked list corrupted`` during the
        capture phase, so it never returns.)
    """
    print(f"\n{'='*60}")
    print(f"Testing backend: {backend}")
    print(f"{'='*60}")
    
    device = torch.device('cuda:0')
    
    # Model setup
    model = models.resnet50(weights=None).to(device)
    model.train()
    model = torch.compile(model, backend=backend)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    
    # Save state for restoration after capture
    # (capture itself runs a real backward pass, and one eager optimizer.step()
    # is taken below before restoring, so snapshot the pristine state here).
    model_bak = copy.deepcopy(model.state_dict())
    optimizer_bak = copy.deepcopy(optimizer.state_dict())
    
    # Static inputs for CUDA graph capture — these tensors' storage is baked
    # into the graph; replays reuse the same buffers.
    static_input = torch.randn(4, 3, 224, 224, device=device)
    static_target = torch.randn(4, 1000, device=device)
    
    # Warmup in separate stream, as required by the CUDA-graphs recipe so
    # lazy initialization and allocator behavior stabilize before capture.
    # NOTE(review): 11 iterations presumably covers torch.compile warmup
    # (guard/autotune passes) — confirm this suffices for the iree_boo backend.
    print("Warmup...")
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(11):
            optimizer.zero_grad(set_to_none=True)
            output = model(static_input)
            loss = loss_fn(output, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)
    
    # CUDA graph capture (no optimizer step inside).
    # Grads are set to None first so .backward() inside the capture allocates
    # the static gradient buffers that later replays will write into.
    print("Capturing CUDA graph...")
    graph = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    
    with torch.cuda.graph(graph):
        static_output = model(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
    
    # Restore state: take one eager step (capture produced real gradients),
    # then roll model and optimizer back to the pre-capture snapshot.
    optimizer.step()
    model.load_state_dict(model_bak)
    optimizer.load_state_dict(optimizer_bak)
    
    print(f"Graph captured. Running {num_steps} replay steps...")
    
    # Replay graph.
    # NOTE(review): zero_grad(set_to_none=True) here sets param.grad to None,
    # detaching params from the static grad buffers captured above; replay
    # still writes into those buffers but step() may then see no grads —
    # consistent with the constant loss in the inductor transcript. Verify
    # against the official recipe (which keeps grads attached during replay).
    for step in range(num_steps):
        optimizer.zero_grad(set_to_none=True)
        graph.replay()
        optimizer.step()
        
        if step % 10 == 0:
            # static_loss aliases the captured loss tensor, updated by replay.
            loss_val = static_loss.item()
            print(f"  Step {step}: loss={loss_val:.4f}")
    
    print(f"✓ PASS: {backend} completed {num_steps} steps")
    return True

if __name__ == "__main__":
    # Exercise both backends; a crash in one must not stop the other,
    # so each run is wrapped in its own try/except with a full traceback.
    for name in ("inductor", "iree_boo"):
        try:
            train_with_cuda_graphs(name)
        except Exception as exc:
            print(f"✗ FAIL: {name} crashed")
            print(f"Error: {exc}")
            import traceback
            traceback.print_exc()

Running the script produces the following output:

============================================================
Testing backend: inductor
============================================================
Warmup...
Capturing CUDA graph...
Graph captured. Running 50 replay steps...
  Step 0: loss=1.3601
  Step 10: loss=1.3601
  Step 20: loss=1.3601
  Step 30: loss=1.3601
  Step 40: loss=1.3601
✓ PASS: inductor completed 50 steps

============================================================
Testing backend: iree_boo
============================================================
Warmup...
Capturing CUDA graph...
malloc(): unsorted double linked list corrupted
Aborted (core dumped)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions