-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Description
Expected behavior
TVM should be able to import and compile an ONNX PRelu model when the slope input is a broadcastable initializer (e.g. shape (1,1)).
Per the ONNX PRelu operator spec, slope may have a lower rank than X, as long as it is unidirectionally broadcastable to X.
Actual behavior
When importing the attached model with TVM Relax (tvm.relax.frontend.onnx.from_onnx) and applying relax.transform.LegalizeOps, TVM fails with an AssertionError raised from TOPI topi.nn.prelu:
assert len(slope.shape) == 1
Traceback (most recent call last):
File "DLCompilers/bug/tvm/prelu_slope_rank_broadcast/run_repro.py", line 58, in _run_tvm
mod = relax.transform.LegalizeOps()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "DLCompilers/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
File "DLCompilers/tvm/src/ir/transform.cc", line 544, in operator()
[](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
File "DLCompilers/tvm/src/ir/transform.cc", line 290, in tvm::transform::Pass::operator()(tvm::IRModule) const
return this->operator()(std::move(mod), PassContext::Current());
File "DLCompilers/tvm/src/ir/transform.cc", line 306, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
ret = node->operator()(std::move(mod), pass_ctx);
File "DLCompilers/tvm/src/ir/transform.cc", line 414, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
mod = pass_func(std::move(mod), pass_ctx);
File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 416, in operator()
mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform();
File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 84, in tvm::relax::LegalizeMutator::Transform()
auto updated_func = Downcast<Function>(this->VisitExpr(func));
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
return builder_->Normalize(ExprFunctor::VisitExpr(expr));
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
return vtable(n, this, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 585, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
Expr body = this->VisitWithNewScope(op->body, params);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 817, in tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, tvm::ffi::Optional<tvm::ffi::Array<tvm::relax::Var, void>, void>)
Expr ret = this->VisitExpr(expr);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
return builder_->Normalize(ExprFunctor::VisitExpr(expr));
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
return vtable(n, this, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 628, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
BindingBlock new_block = this->VisitBindingBlock(block);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 776, in tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
ret = VisitBindingBlock_(node);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 734, in tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
this->VisitBinding(binding);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 652, in tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::ConstantNode const*)
RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode);
File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
return builder_->Normalize(ExprFunctor::VisitExpr(expr));
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
return vtable(n, this, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 357, in tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
Expr legalized = legalization_func(builder_, visited_call);
File "python/tvm_ffi/cython/function.pxi", line 1058, in tvm_ffi.core.tvm_ffi_callback
File "DLCompilers/tvm/python/tvm/relax/transform/legalize_ops/nn.py", line 493, in _nn_prelu
return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)
File "DLCompilers/tvm/python/tvm/relax/block_builder.py", line 361, in call_te
tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
File "DLCompilers/tvm/python/tvm/relax/utils.py", line 355, in gen_call_tir_inputs
te_out = func(*te_args, **te_kwargs)
File "DLCompilers/tvm/python/tvm/te/tag.py", line 57, in tagged_fdecl
return fdecl(*args, **kwargs)
File "DLCompilers/tvm/python/tvm/topi/nn/elemwise.py", line 130, in prelu
assert len(slope.shape) == 1
AssertionError
ONNX Runtime can execute the same model successfully.
ONNXRuntime:
[array([[ 1.117622 , -5.554099 , -1.7080084 , -3.217593 , 0.60142773, -0.3002757 , 0.05969319, -0.12815356, -0.7426875 , 1.2047737 , 0.77745306, -5.438607 , 0.76982784, -3.484421 ,
1.0997635 , -3.8377662 , -5.1048746 , -5.466864 , -5.903262 , 0.4335317 , -1.313807 , -3.8877916 ]], dtype=float32)]
[ort] output y shape= (1, 22) dtype= float32 min/max= (-5.903262, 1.2047737)
Environment
Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
PyTorch version: 2.9.1
ONNX Runtime version: 1.23.2
ONNX version: 1.20.0
Python: 3.11.14
Steps to reproduce
Download the model and the oracle inputs, save the script below as run_repro.py, and run the following command to reproduce the results:
python run_repro.py --model model.onnx --oracle oracle.pkl
from __future__ import annotations
import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path
import numpy as np
def _ensure_repo_tvm() -> None:
    """Prefer a TVM checkout located near this script over an installed TVM.

    The checkout is looked up at ``<repo_root>/tvm``, where ``<repo_root>`` is
    three directories above the directory that contains this file.

    * If ``<repo_root>/tvm/python`` exists, it is inserted at the front of
      ``sys.path``.
    * If ``<repo_root>/tvm/build`` exists and ``TVM_LIBRARY_PATH`` is not
      already set, the variable is pointed at that directory.

    When the script is not nested deeply enough for such a repo root to exist,
    nothing is modified and whichever TVM is importable gets used.
    """
    ancestors = Path(__file__).resolve().parents
    if len(ancestors) <= 3:
        # The script was copied somewhere shallow (e.g. /tmp/run_repro.py), so
        # there is no repo root to probe; indexing parents[3] would raise
        # IndexError. Leave the environment untouched instead.
        return
    repo_root = ancestors[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    # Never override a library path the user has chosen explicitly.
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()
def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    """Load the model inputs stored in an oracle pickle file.

    The pickle may hold either a mapping with an ``"input"`` entry that is a
    dict of inputs, or the dict of inputs itself.

    Args:
        path: Location of the pickle file.

    Returns:
        A dict with the same keys as the stored inputs, each value converted
        with ``np.array``.

    Raises:
        ValueError: If the stored inputs are not a dict.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only load oracle
    # files that come from a trusted source.
    obj = pickle.loads(path.read_bytes())
    # Only call .get() when the payload provides it; otherwise a payload such
    # as a list would fail with AttributeError instead of the ValueError below.
    getter = getattr(obj, "get", None)
    inp = getter("input", obj) if callable(getter) else obj
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}
def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Run the model with ONNX Runtime on the CPU provider and print the outputs.

    Prints the full list of output arrays, followed by one summary line per
    output with its name, shape, dtype and (min, max) value range.
    """
    import onnxruntime as ort  # type: ignore

    # Print arrays in full (no "..." truncation) on wide lines.
    np.set_printoptions(threshold=np.inf, linewidth=200)
    session = ort.InferenceSession(
        model_path.as_posix(), providers=["CPUExecutionProvider"]
    )
    results = [np.array(value) for value in session.run(None, inputs)]
    print("ONNXRuntime:\n", results)
    for meta, arr in zip(session.get_outputs(), results):
        value_range = (arr.min(), arr.max())
        print(
            "[ort] output",
            meta.name,
            "shape=",
            arr.shape,
            "dtype=",
            arr.dtype,
            "min/max=",
            value_range,
        )
def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Import the model with the TVM Relax ONNX frontend and try to build it for LLVM.

    Any exception raised during import, legalization or build is caught and
    reported on stdout (exception type, full traceback, then ``repr``) rather
    than propagated. Loading the ONNX file itself happens outside that guard.
    """
    _ensure_repo_tvm()
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    model_proto = onnx.load(model_path.as_posix())
    input_shapes = {name: value.shape for name, value in inputs.items()}
    print("[tvm] shape_dict:", input_shapes)
    try:
        imported = rx_onnx.from_onnx(model_proto, shape_dict=input_shapes)
        # from_onnx may hand back a sequence; the module is its first element.
        if isinstance(imported, (list, tuple)):
            mod = imported[0]
        else:
            mod = imported
        mod = relax.transform.DecomposeOpsForInference()(mod)
        # The failure is expected on the next line: topi.nn.prelu asserts that
        # slope is 1-D.
        mod = relax.transform.LegalizeOps()(mod)
        mod, params = relax.frontend.detach_params(mod)
        target = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(target)
        pass_config = {"tir.enable_debug": False}
        with tvm.transform.PassContext(opt_level=3, config=pass_config):
            relax.build(mod, target=target, params=params, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded")
    except Exception as exc:
        print("[tvm] FAILED:", type(exc).__name__)
        trace_text = traceback.format_exc()
        # Avoid a doubled blank line when the traceback already ends in one.
        terminator = "" if trace_text.endswith("\n") else "\n"
        print(trace_text, end=terminator)
        print("\n[tvm] error repr:\n" + repr(exc))
def main() -> int:
    """Parse the command line, then run the model with ONNX Runtime and with TVM.

    Options:
        --model   Path to the ONNX model (default: ``model.onnx``).
        --oracle  Path to the pickled inputs (default: ``oracle.pkl``).

    Returns:
        Always 0. A TVM failure is reported by ``_run_tvm`` on stdout and does
        not change the exit status.
    """
    parser = argparse.ArgumentParser()
    for flag, fallback in (("--model", "model.onnx"), ("--oracle", "oracle.pkl")):
        parser.add_argument(flag, type=Path, default=Path(fallback))
    opts = parser.parse_args()

    model_file = opts.model.resolve()
    oracle_file = opts.oracle.resolve()
    feeds = _load_oracle_inputs(oracle_file)
    _run_ort(model_file, feeds)
    _run_tvm(model_file, feeds)
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
Triage
- needs-triage
