-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Description
Summary
iree_boo backend crashes with malloc corruption during CUDA graph capture. The inductor backend works correctly with the same code. Crash occurs during the torch.cuda.graph() capture phase.
Error
malloc(): unsorted double linked list corrupted
Aborted (core dumped)
Environment
- PyTorch: 2.9.1+rocm7.1.1.lw.git351ff442
- IREE Turbine: 3.10.0rc20260204
- Hardware: AMD MI355X (ROCm 7.1.1)
- Python 3.12.3
Reproduction
#!/usr/bin/env python3
"""
Minimal reproducer for iree_boo + CUDA graphs memory corruption `Error: malloc(): unsorted double linked list corrupted`
Usage: python iree_cuda_graphs_bug.py
"""
import torch
import torch.nn as nn
import torchvision.models as models
import copy
def train_with_cuda_graphs(backend, num_steps=50):
    """Capture one ResNet-50 training step in a CUDA graph and replay it.

    Compiles the model with the given ``torch.compile`` backend, warms up on a
    side stream, captures forward+loss+backward (but NOT the optimizer step)
    into a ``torch.cuda.CUDAGraph``, then replays the graph ``num_steps`` times
    with the optimizer step executed eagerly outside the graph.

    Args:
        backend: torch.compile backend name, e.g. "inductor" or "iree_boo".
        num_steps: number of graph replay iterations after capture.

    Returns:
        True when all replay steps complete without raising.
    """
    print(f"\n{'='*60}")
    print(f"Testing backend: {backend}")
    print(f"{'='*60}")
    device = torch.device('cuda:0')
    # Model setup
    model = models.resnet50(weights=None).to(device)
    model.train()
    model = torch.compile(model, backend=backend)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    # Save state for restoration after capture (warmup + capture mutate weights
    # and optimizer state).
    model_bak = copy.deepcopy(model.state_dict())
    optimizer_bak = copy.deepcopy(optimizer.state_dict())
    # Static inputs for CUDA graph capture — the graph records fixed tensor
    # addresses, so these buffers must stay alive and be reused for replays.
    static_input = torch.randn(4, 3, 224, 224, device=device)
    static_target = torch.randn(4, 1000, device=device)
    # Warmup in separate stream. Warmup iterations on a non-default stream are
    # the pattern required before CUDA graph capture (lets allocator/compile
    # caches settle); 11 iterations covers lazy init and recompiles.
    print("Warmup...")
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(11):
            optimizer.zero_grad(set_to_none=True)
            output = model(static_input)
            loss = loss_fn(output, static_target)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)
    # CUDA graph capture (no optimizer step inside) — only forward, loss, and
    # backward are recorded; the crash under iree_boo occurs in this phase.
    print("Capturing CUDA graph...")
    graph = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(graph):
        static_output = model(static_input)
        static_loss = loss_fn(static_output, static_target)
        static_loss.backward()
    # Restore state. NOTE(review): the optimizer.step() here runs eagerly after
    # capture — presumably to materialize optimizer state before it is
    # snapshotted back; the subsequent load_state_dict calls undo its effect.
    optimizer.step()
    model.load_state_dict(model_bak)
    optimizer.load_state_dict(optimizer_bak)
    print(f"Graph captured. Running {num_steps} replay steps...")
    # Replay graph: each replay fills the static gradient buffers; the eager
    # optimizer.step() then consumes them.
    for step in range(num_steps):
        optimizer.zero_grad(set_to_none=True)
        graph.replay()
        optimizer.step()
        if step % 10 == 0:
            # static_loss is the capture-time tensor; replay updates it in place.
            loss_val = static_loss.item()
            print(f" Step {step}: loss={loss_val:.4f}")
    print(f"✓ PASS: {backend} completed {num_steps} steps")
    return True
if __name__ == "__main__":
    # Run the reproducer against both backends. Python-level exceptions are
    # reported per backend; a native abort (e.g. malloc corruption) will still
    # terminate the whole process before this handler can run.
    for backend in ("inductor", "iree_boo"):
        try:
            train_with_cuda_graphs(backend)
        except Exception as e:
            print(f"✗ FAIL: {backend} crashed")
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
Running the script produces the following output:
============================================================
Testing backend: inductor
============================================================
Warmup...
Capturing CUDA graph...
Graph captured. Running 50 replay steps...
Step 0: loss=1.3601
Step 10: loss=1.3601
Step 20: loss=1.3601
Step 30: loss=1.3601
Step 40: loss=1.3601
✓ PASS: inductor completed 50 steps
============================================================
Testing backend: iree_boo
============================================================
Warmup...
Capturing CUDA graph...
malloc(): unsorted double linked list corrupted
Aborted (core dumped)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels