Summary
torch.compile(..., backend="iree_boo_inductor") fails at runtime when the model contains nn.Dropout. The failure is a lowering error involving aten.rand.default (“should have been handled in replace_random.py”). The same model runs successfully with backend="inductor" (default) and backend="iree_boo". This makes it impossible to compile dropout-equipped models (e.g. many torchvision classifiers) with iree_boo_inductor.
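The aten.rand.default in the traceback comes from dropout itself: in training mode nn.Dropout samples a random mask (the [2, 8] shape in the error matches the dropout input below), while in eval mode the module is an identity and emits no random op. A minimal eager-mode illustration of that behavior, with no compilation involved:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(2, 8)

# Training mode: a random mask is sampled (this is the op that shows up
# as aten.rand under torch.compile), zeroing roughly half the elements
# and scaling survivors by 1 / (1 - p).
drop.train()
train_out = drop(x)

# Eval mode: dropout is the identity, so no random op is needed.
drop.eval()
eval_out = drop(x)
assert torch.equal(eval_out, x)
```

This suggests the failure is specific to compiling in training mode; compiling after model.eval() would not exercise the aten.rand lowering path (though that is not a fix for training).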
Error
```
torch._inductor.exc.InductorError: LoweringException: AssertionError: should have been handled in replace_random.py
  target: aten.rand.default
  args[0]: [2, 8]
  kwargs: {'dtype': torch.float32, 'device': device(type='cuda', index=0), 'pin_memory': False}
```
Environment
PyTorch: 2.9.1+rocm7.1.1.lw.git351ff442
IREE Turbine: built from source on commit 96cbd0c
Hardware: AMD MI355X (ROCm 7.1.1)
Python 3.12.3
Reproduction
"""
Minimal repro: torch.compile(..., backend="iree_boo_inductor") fails to run models with Dropout.
Usage: python iree_dropout_repro.py
"""
import torch
import torch.nn as nn
class SmallModelWithDropout(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 8, 3, padding=1)
self.pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(0.5)
self.fc = nn.Linear(8, 10)
def forward(self, x):
x = torch.relu(self.conv(x))
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x)
x = self.fc(x)
return x
def test(backend):
device = torch.device("cuda:0")
model = SmallModelWithDropout().to(device)
model = torch.compile(model, backend=backend)
model.train()
x = torch.randn(2, 3, 32, 32, device=device)
y = model(x)
loss = y.sum()
loss.backward()
return "PASS"
if __name__ == "__main__":
backends = ["inductor", "iree_boo", "iree_boo_inductor"]
for b in backends:
try:
test(b)
print(f"{b}: PASS")
except Exception as e:
print(f"{b}: FAIL - {type(e).__name__} - {e}")
Output
```
inductor: PASS
iree_boo: PASS
/opt/venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:312: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
iree_boo_inductor: FAIL - InductorError - LoweringException: AssertionError: should have been handled in replace_random.py
  target: aten.rand.default
  args[0]: [2, 8]
  kwargs: {'dtype': torch.float32, 'device': device(type='cuda', index=0), 'pin_memory': False}
```
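A possible mitigation worth noting, as a sketch only (not verified here against the iree_boo_inductor backend): Inductor exposes a fallback_random config flag that makes compiled graphs call eager-mode RNG kernels instead of lowering random ops like aten.rand itself, which is the op the replace_random.py pass asserts on.

```python
import torch._inductor.config as inductor_config

# Sketch of a possible workaround (untested with iree_boo_inductor):
# with fallback_random enabled, Inductor falls back to eager RNG kernels
# for ops like aten.rand instead of lowering them, which may sidestep
# the "should have been handled in replace_random.py" assertion.
inductor_config.fallback_random = True
```

Even if this avoids the crash, whether iree_boo_inductor honors the flag is an open question for this issue.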