Skip to content

[Bug] PRelu import/compile fails when slope initializer is broadcastable rank-2 (1,1): topi.nn.prelu requires 1-D slope #18607

@dutZ1855

Description

@dutZ1855

Expected behavior

TVM should be able to import and compile an ONNX PRelu model when the slope input is a broadcastable initializer (e.g. shape (1,1)).

Per the ONNX PRelu operator spec, slope may have a lower rank than X as long as it is unidirectionally broadcastable to X.

Actual behavior

For the following model,
Image

When importing the attached model with TVM Relax (tvm.relax.frontend.onnx.from_onnx) and applying relax.transform.LegalizeOps, TVM fails with an AssertionError raised from TOPI topi.nn.prelu:

  • assert len(slope.shape) == 1
Traceback (most recent call last):
  File "DLCompilers/bug/tvm/prelu_slope_rank_broadcast/run_repro.py", line 58, in _run_tvm
    mod = relax.transform.LegalizeOps()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "DLCompilers/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
  File "DLCompilers/tvm/src/ir/transform.cc", line 544, in operator()
    [](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
    
  File "DLCompilers/tvm/src/ir/transform.cc", line 290, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
    
  File "DLCompilers/tvm/src/ir/transform.cc", line 306, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
    
  File "DLCompilers/tvm/src/ir/transform.cc", line 414, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
    
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 416, in operator()
    mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform();
    
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 84, in tvm::relax::LegalizeMutator::Transform()
    auto updated_func = Downcast<Function>(this->VisitExpr(func));
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 585, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
    Expr body = this->VisitWithNewScope(op->body, params);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 817, in tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, tvm::ffi::Optional<tvm::ffi::Array<tvm::relax::Var, void>, void>)
    Expr ret = this->VisitExpr(expr);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 628, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
    BindingBlock new_block = this->VisitBindingBlock(block);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 776, in tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
    ret = VisitBindingBlock_(node);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 734, in tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
    this->VisitBinding(binding);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 652, in tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::ConstantNode const*)
    RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode);
    
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
    
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
    
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 357, in tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
    Expr legalized = legalization_func(builder_, visited_call);
    
  File "python/tvm_ffi/cython/function.pxi", line 1058, in tvm_ffi.core.tvm_ffi_callback
  File "DLCompilers/tvm/python/tvm/relax/transform/legalize_ops/nn.py", line 493, in _nn_prelu
    return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)
    
  File "DLCompilers/tvm/python/tvm/relax/block_builder.py", line 361, in call_te
    tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
    
  File "DLCompilers/tvm/python/tvm/relax/utils.py", line 355, in gen_call_tir_inputs
    te_out = func(*te_args, **te_kwargs)
    
  File "DLCompilers/tvm/python/tvm/te/tag.py", line 57, in tagged_fdecl
    return fdecl(*args, **kwargs)
    
  File "DLCompilers/tvm/python/tvm/topi/nn/elemwise.py", line 130, in prelu
    assert len(slope.shape) == 1
    
AssertionError

ONNX Runtime can execute the same model successfully.

ONNXRuntime:
 [array([[ 1.117622  , -5.554099  , -1.7080084 , -3.217593  ,  0.60142773, -0.3002757 ,  0.05969319, -0.12815356, -0.7426875 ,  1.2047737 ,  0.77745306, -5.438607  ,  0.76982784, -3.484421  ,
         1.0997635 , -3.8377662 , -5.1048746 , -5.466864  , -5.903262  ,  0.4335317 , -1.313807  , -3.8877916 ]], dtype=float32)]
[ort] output y shape= (1, 22) dtype= float32 min/max= (-5.903262, 1.2047737)

Environment

Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
PyTorch version: 2.9.1
ONNX Runtime version: 1.23.2
ONNX version: 1.20.0
Python: 3.11.14

Steps to reproduce

model.zip

Download the model and run the following code to obtain the results.
python run_repro.py --model model.onnx --oracle oracle.pkl

from __future__ import annotations

import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np


def _ensure_repo_tvm() -> None:
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input", obj)
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}


def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Run the model with ONNX Runtime on the CPU provider and print the results.

    Prints the full list of output arrays, then one summary line per output
    with its name, shape, dtype and (min, max) values.

    Args:
        model_path: Path to the ONNX model file.
        inputs: Feed dict passed to the inference session.
    """
    import onnxruntime as ort  # type: ignore

    # Print arrays in full (no "..." elision) on wide lines.
    np.set_printoptions(threshold=np.inf, linewidth=200)
    session = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    results = [np.array(raw) for raw in session.run(None, inputs)]
    print("ONNXRuntime:\n", results)
    for meta, arr in zip(session.get_outputs(), results):
        value_range = (arr.min(), arr.max())
        print("[ort] output", meta.name, "shape=", arr.shape, "dtype=", arr.dtype, "min/max=", value_range)


def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Import, legalize and build the model with TVM Relax, reporting the outcome.

    Any exception raised during import/compile is caught and printed (type,
    full traceback, and repr) rather than propagated, so the script always
    finishes and shows the failure.

    Args:
        model_path: Path to the ONNX model file.
        inputs: Input arrays; only their shapes are used, to build ``shape_dict``.
    """
    _ensure_repo_tvm()
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    graph = onnx.load(model_path.as_posix())
    shape_dict = {}
    for name, value in inputs.items():
        shape_dict[name] = value.shape
    print("[tvm] shape_dict:", shape_dict)
    try:
        imported = rx_onnx.from_onnx(graph, shape_dict=shape_dict)
        # from_onnx may hand back a sequence; the module is its first element.
        if isinstance(imported, (list, tuple)):
            ir_mod = imported[0]
        else:
            ir_mod = imported
        ir_mod = relax.transform.DecomposeOpsForInference()(ir_mod)
        # This is the step expected to FAIL: topi.nn.prelu requires a 1-D slope.
        ir_mod = relax.transform.LegalizeOps()(ir_mod)
        ir_mod, weights = relax.frontend.detach_params(ir_mod)
        target = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(target)
        with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
            relax.build(ir_mod, target=target, params=weights, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded")
    except Exception as exc:
        print("[tvm] FAILED:", type(exc).__name__)
        trace_text = traceback.format_exc()
        # Avoid emitting a second newline when the traceback already ends in one.
        terminator = "" if trace_text.endswith("\n") else "\n"
        print(trace_text, end=terminator)
        print("\n[tvm] error repr:\n" + repr(exc))


def main() -> int:
    """Parse the command line, then run the model under ONNX Runtime and TVM.

    Command-line options:
        --model: ONNX model path (default ``model.onnx``).
        --oracle: Pickled inputs path (default ``oracle.pkl``).

    Returns:
        Always 0; a TVM failure is printed by ``_run_tvm`` rather than raised.
    """
    parser = argparse.ArgumentParser()
    for flag, fallback in (("--model", "model.onnx"), ("--oracle", "oracle.pkl")):
        parser.add_argument(flag, type=Path, default=Path(fallback))
    opts = parser.parse_args()

    model_path = opts.model.resolve()
    inputs = _load_oracle_inputs(opts.oracle.resolve())

    # Reference run first, then the TVM run that reproduces the bug.
    for runner in (_run_ort, _run_tvm):
        runner(model_path, inputs)
    return 0


# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())

Triage

  • needs-triage

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions