Description
Summary
I am seeing a hard segfault (no Python exception) during tvm.compile(...) for a CUDA target. The crash consistently occurs inside the TIR pass:
tvm::tir::transform::InjectPTXLDG32(bool)
  tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
    tvm::tir::BufferStore::BufferStore(...)
The input IRModule is produced by converting a PyTorch torch.export program using tvm.relax.frontend.torch.from_exported_program. The PyTorch model is intentionally small (Linear(4,4)) and returns a tuple of tensors: (torch.tril(x), torch.triu(x)).
This looks like a bug in the InjectPTXLDG32 rewrite logic, or an unsafe assumption in the pass that leads to a null/invalid BufferStore construction.
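One way to confirm that the flag gates the crash (a sketch, not verified here): rerun the repro in a subprocess with tir.ptx_ldg32 toggled, since the segfault kills the interpreter and cannot be caught in-process. This assumes a hypothetical variant of the repro script below that reads REPRO_PTX_LDG32 from the environment and places it into pass_config.

    # Hypothetical bisection driver: run the repro twice, with and without
    # "tir.ptx_ldg32". On POSIX, a negative returncode means the child
    # process died on a signal (e.g. -11 for SIGSEGV).
    import os
    import subprocess
    import sys

    for ldg32 in (0, 1):
        env = dict(os.environ, REPRO_PTX_LDG32=str(ldg32))
        ret = subprocess.run([sys.executable, "repro.py"], env=env)
        status = "crashed (signal %d)" % -ret.returncode if ret.returncode < 0 else "ok"
        print("tir.ptx_ldg32 =", ldg32, "->", status)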
Environment
From the repro output:
- TVM version: 0.22.0
- TVM commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
- LLVM: 17.0.6
- Python: 3.10.16 (from stack paths)
- NumPy: 2.2.6
- PyTorch: 2.9.0+cu128
- CUDA GPU: NVIDIA RTX A6000 (sm_86)
Target string used:
cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
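For completeness, a small sanity check (assuming it is run on the same machine as the repro) that the target string parses and that sm_86 matches the local GPU's compute capability:

    # Verify the target string parses and the local device matches sm_86.
    import tvm

    target = tvm.target.Target(
        "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
    )
    print("parsed target:", target)
    # compute_version is e.g. "8.6" on an RTX A6000, i.e. sm_86.
    print("local GPU compute capability:", tvm.cuda(0).compute_version)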
Minimal Repro Script
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn

import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        print("TVM git commit:", tvm.support.libinfo().get("GIT_COMMIT_HASH", "unknown"))
    except Exception:
        print("TVM git commit: unknown")
    try:
        print("TVM LLVM version:", tvm.support.libinfo().get("LLVM_VERSION", "unknown"))
    except Exception:
        print("TVM LLVM version: unknown")
    print("Python (numpy) version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("CUDA available (torch):", torch.cuda.is_available())
    if torch.cuda.is_available():
        try:
            print("CUDA device:", torch.cuda.get_device_name(0))
        except Exception:
            pass
    print("==========================\n")


class MyModel(nn.Module):
    """Intentionally small model returning a tuple of tensors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear(x)
        return torch.tril(x), torch.triu(x)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this repro, but torch.cuda.is_available() is False")

    target_str = "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
    target = tvm.target.Target(target_str)
    relax_pipeline = "default"
    tir_pipeline = "default"

    model = MyModel()
    x = torch.zeros((1, 4), dtype=torch.float32)

    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)

    # Passes disabled in the PassContext for this repro.
    disabled_pass = [
        "DeadCodeElimination",
        "CanonicalizeBindings",
        "Simplify",
        "UnrollLoop",
        "VectorizeLoop",
        "StorageRewrite",
        "RemoveNoOp",
        "LoopPartition",
    ]
    # "tir.ptx_ldg32": 1 is what enables the InjectPTXLDG32 pass seen in
    # the crash stack.
    pass_config = {
        "relax.FuseOps.max_depth": 2,
        "relax.lift_transform_params.consume_params": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.instrument_bound_checkers": 1,
        "tir.merge_static_smem": 1,
        "tir.noalias": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }
    pc_kwargs = {
        "opt_level": 3,
        "disabled_pass": disabled_pass,
        "config": pass_config,
    }

    print("[repro] target:", target)
    print("[repro] relax_pipeline:", relax_pipeline)
    print("[repro] tir_pipeline:", tir_pipeline)
    print("[repro] opt_level:", pc_kwargs["opt_level"])
    print("[repro] disabled_pass:", disabled_pass)
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))

    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc_kwargs):
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )
    print("[repro] compile finished (no crash).")


if __name__ == "__main__":
    main()

Actual Behavior
Segfault during compilation:
!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
This is a hard crash (core dumped), not a recoverable error.
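To aid triage, the Python-side location of the crash can be captured by enabling the standard-library faulthandler at the top of the repro; it dumps the interpreter traceback when SIGSEGV arrives (the native frames above still require gdb or the core dump):

    # Optional addition at the top of the repro script: on SIGSEGV, Python
    # prints its own traceback before the process dies, pinpointing which
    # repro line triggered the native crash.
    import faulthandler
    faulthandler.enable()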
Expected Behavior
tvm.compile(...) should either:
- successfully compile the module, or
- raise a normal Python exception / diagnostic if some pass config is invalid,
but it should not segfault.
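For reference, this is the recoverable failure mode the report expects, a sketch assuming TVM surfaces configuration errors as tvm.error.TVMError (ir_mod and target as in the repro above):

    # Sketch of the expected path: an invalid pass config should surface as
    # an exception here instead of killing the process.
    import tvm

    try:
        with tvm.transform.PassContext(opt_level=3, config={"tir.ptx_ldg32": 1}):
            lib = tvm.compile(ir_mod, target=target)
    except tvm.error.TVMError as err:
        print("compile failed cleanly:", err)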
Triage
- needs-triage
- bug