Description
Expected behavior
The TVM ONNX frontend should correctly implement the "uneven split" logic for the Split operator as defined in Opset 18+. When the num_outputs attribute is provided:
- It should calculate `block_size = ceil(dimension / num_outputs)`.
- The first `num_outputs - 1` outputs should each have size `block_size`.
- The last output should contain the remainder.
For an input length of 10 and num_outputs=3, the expected output shapes are [4, 4, 2].
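A minimal sketch of the expected size computation (the helper name `split_sizes` is illustrative, not part of TVM or ONNX):

```python
import math

def split_sizes(dim, num_outputs):
    # Opset 18 uneven split: block_size = ceil(dim / num_outputs);
    # the first num_outputs - 1 chunks get block_size, the last gets the rest.
    block_size = math.ceil(dim / num_outputs)
    sizes = [block_size] * (num_outputs - 1)
    sizes.append(dim - block_size * (num_outputs - 1))
    return sizes

print(split_sizes(10, 3))  # [4, 4, 2]
print(split_sizes(9, 3))   # [3, 3, 3]
```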
Actual behavior
TVM correctly handles uniform splits (e.g., 9/3), but fails to convert the model when an uneven split is required (e.g., 10/3). The frontend throws a conversion error, indicating it cannot handle dimensions that are not perfectly divisible by the number of outputs.
Reproduction log:

```
>>> Testing Split: Input Length 9 / 3 parts
  ONNX Runtime shapes: [3, 3, 3]
  TVM shapes: [3, 3, 3]
  Result: PASS

>>> Testing Split: Input Length 10 / 3 parts
  ONNX Runtime shapes: [4, 4, 2]
  Error converting operator Split, with inputs: [X]
  Result: FAIL (Conversion or Runtime Error)
  Error: Traceback (most recent call last): ... src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
```
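For reference, the uneven split is equivalent to cutting at explicit indices `[block_size, 2*block_size, ...]`, which is one way a frontend could lower it to an indices-based split op (a sketch of the semantics using NumPy, not of TVM internals):

```python
import numpy as np

x = np.arange(10, dtype=np.float32)
num_outputs = 3
block_size = -(-len(x) // num_outputs)                 # ceil(10 / 3) == 4
cuts = [block_size * i for i in range(1, num_outputs)] # [4, 8]
parts = np.split(x, cuts)
print([p.shape[0] for p in parts])  # [4, 4, 2]
```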
Environment
- OS: Ubuntu 20.04.6 LTS (Focal Fossa)
- TVM Version: 0.19.0 (Relax)
- ONNX Version: 1.18.0
- ONNX Runtime Version: 1.24.1
- NumPy Version: 2.4.2
Steps to reproduce
```python
import onnx
from onnx import helper, TensorProto
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import onnxruntime as ort


def run_split_test(input_len, num_outputs):
    print(f"\n>>> Testing Split: Input Length {input_len} / {num_outputs} parts")

    # 1. Construct ONNX model
    x_np = np.arange(input_len).astype(np.float32)
    node = helper.make_node(
        'Split',
        inputs=['X'],
        outputs=[f'Y{i}' for i in range(num_outputs)],
        axis=0,
        num_outputs=num_outputs
    )
    graph = helper.make_graph(
        [node],
        'split_test',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, [input_len])],
        [helper.make_tensor_value_info(f'Y{i}', TensorProto.FLOAT, [None]) for i in range(num_outputs)]
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

    # 2. Reference output (ONNX Runtime)
    sess = ort.InferenceSession(model.SerializeToString())
    ort_outs = sess.run(None, {'X': x_np})
    ort_shapes = [o.shape[0] for o in ort_outs]
    print(f"  ONNX Runtime shapes: {ort_shapes}")

    # 3. TVM output
    try:
        tvm_mod = from_onnx(model)
        target = tvm.target.Target("llvm")
        exe = relax.build(tvm_mod, target)
        vm = relax.VirtualMachine(exe, tvm.cpu())
        tvm_outs = vm["main"](tvm.nd.array(x_np))
        tvm_shapes = [o.asnumpy().shape[0] for o in tvm_outs]
        print(f"  TVM shapes: {tvm_shapes}")
        if tvm_shapes == ort_shapes:
            print("  Result: PASS")
        else:
            print("  Result: FAIL (Shape Mismatch)")
    except Exception as e:
        print("  Result: FAIL (Conversion or Runtime Error)")
        print(f"  Error: {str(e)[:100]}...")


if __name__ == "__main__":
    # Case 1: Uniform split (should PASS)
    run_split_test(input_len=9, num_outputs=3)
    # Case 2: Non-uniform split (should FAIL)
    run_split_test(input_len=10, num_outputs=3)
```

Triage
- relax:frontend:onnx
- needs-triage