[Bug] Segfault in tvm.compile (Relax→TIR, CUDA target) inside tir::transform::InjectPTXLDG32 / PTXRewriter::VisitStmt_(BufferStore) when compiling torch.export model returning (tril, triu) tuple #18612

@tinywisdom

Description

Summary

I am seeing a hard segfault (no Python exception) during tvm.compile(...) for a CUDA target. The crash consistently occurs in the following frames of the TIR pass InjectPTXLDG32:

  • tvm::tir::transform::InjectPTXLDG32(bool)
  • tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
  • tvm::tir::BufferStore::BufferStore(...)

The input IRModule is produced by converting a PyTorch torch.export program using tvm.relax.frontend.torch.from_exported_program. The PyTorch model is intentionally small (Linear(4,4)) and returns a tuple of tensors: (torch.tril(x), torch.triu(x)).

This looks like a bug in the InjectPTXLDG32 rewrite logic, or an unsafe assumption in the pass that leads to constructing a null/invalid BufferStore.
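
For triage it may help to run the suspect pass in isolation. Below is a minimal sketch, assuming the pass is exposed in Python as tvm.tir.transform.InjectPTXLDG32 (mirroring the tvm::tir::transform::InjectPTXLDG32(bool) constructor in the stack trace). On a plain host-side PrimFunc like this one the rewrite may simply be a no-op, so a GPU-scheduled variant of the same function would be the next step toward the crashing BufferStore path:

import tvm
from tvm.script import tir as T

@T.prim_func
def copy(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (4,), "float32")
    B = T.match_buffer(b, (4,), "float32")
    for i in range(4):
        B[i] = A[i]  # a BufferStore node, the kind PTXRewriter::VisitStmt_ rewrites

mod = tvm.IRModule({"copy": copy})
# Assumed Python binding of the suspect pass; the bool mirrors
# tvm::tir::transform::InjectPTXLDG32(bool) from the crash stack.
mod = tvm.tir.transform.InjectPTXLDG32(True)(mod)
print(mod)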


Environment

From the repro output:

  • TVM version: 0.22.0
  • TVM commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
  • LLVM: 17.0.6
  • Python: 3.10.16 (from stack paths)
  • NumPy: 2.2.6
  • PyTorch: 2.9.0+cu128
  • CUDA GPU: NVIDIA RTX A6000 (sm_86)

Target string used:

cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32

Minimal Repro Script

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        print("TVM git commit:", tvm.support.libinfo().get("GIT_COMMIT_HASH", "unknown"))
    except Exception:
        print("TVM git commit: unknown")
    try:
        print("TVM LLVM version:", tvm.support.libinfo().get("LLVM_VERSION", "unknown"))
    except Exception:
        print("TVM LLVM version: unknown")
    print("Python (numpy) version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("CUDA available (torch):", torch.cuda.is_available())
    if torch.cuda.is_available():
        try:
            print("CUDA device:", torch.cuda.get_device_name(0))
        except Exception:
            pass
    print("==========================\n")


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear(x)
        return torch.tril(x), torch.triu(x)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this repro, but torch.cuda.is_available() is False")

    target_str = "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
    target = tvm.target.Target(target_str)

    relax_pipeline = "default"
    tir_pipeline = "default"

    model = MyModel()
    x = torch.zeros((1, 4), dtype=torch.float32)

    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)

    disabled_pass = [
        "DeadCodeElimination",
        "CanonicalizeBindings",
        "Simplify",
        "UnrollLoop",
        "VectorizeLoop",
        "StorageRewrite",
        "RemoveNoOp",
        "LoopPartition",
    ]

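    # "tir.ptx_ldg32" gates the InjectPTXLDG32 pass that appears in the
    # crash stack; the remaining entries reproduce the failing configuration.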
    pass_config = {
        "relax.FuseOps.max_depth": 2,
        "relax.lift_transform_params.consume_params": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.instrument_bound_checkers": 1,
        "tir.merge_static_smem": 1,
        "tir.noalias": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }

    pc_kwargs = {
        "opt_level": 3,
        "disabled_pass": disabled_pass,
        "config": pass_config,
    }

    print("[repro] target:", target)
    print("[repro] relax_pipeline:", relax_pipeline)
    print("[repro] tir_pipeline:", tir_pipeline)
    print("[repro] opt_level:", pc_kwargs["opt_level"])
    print("[repro] disabled_pass:", disabled_pass)
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))
    print("[repro] compiling with tvm.compile ...")

    with tvm.transform.PassContext(**pc_kwargs):
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )

    print("[repro] compile finished (no crash).")


if __name__ == "__main__":
    main()

Actual Behavior

Segfault during compilation:

!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)

This is a hard crash (core dumped), not a recoverable error.
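
Since the crash is native, no Python traceback identifies the failing pass. A pass instrument that prints (and flushes) each pass name before it runs can confirm that InjectPTXLDG32 is the last pass entered before the segfault. A sketch using the tvm.ir.instrument API, reusing ir_mod, target, pc_kwargs, and the pipeline names from the repro script:

from tvm.ir.instrument import pass_instrument

@pass_instrument
class LogBeforePass:
    # Print every pass name before it executes; the last line emitted
    # before "Segmentation fault" names the crashing pass.
    def run_before_pass(self, mod, info):
        print("[pass]", info.name, flush=True)

with tvm.transform.PassContext(instruments=[LogBeforePass()], **pc_kwargs):
    tvm.compile(ir_mod, target=target,
                relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline)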


Expected Behavior

tvm.compile(...) should either:

  • successfully compile the module, or
  • raise a normal Python exception / diagnostic if some pass config is invalid,

but it should not segfault.
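
A likely workaround (untested here) is to drop or zero the tir.ptx_ldg32 entry, which should presumably skip the InjectPTXLDG32 rewrite; if the same module then compiles, that confirms the flag gates the crash:

pass_config["tir.ptx_ldg32"] = 0  # presumed to disable the InjectPTXLDG32 rewrite
with tvm.transform.PassContext(**pc_kwargs):
    tvm.compile(ir_mod, target=target,
                relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline)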

Triage

  • needs-triage
  • bug
