
[Bug][Frontend][ONNX] Split fails to handle uneven splitting with 'num_outputs' in Opset 18 #18751

@huenwei-arch

Description

Expected behavior

The TVM ONNX frontend should correctly implement the "uneven split" logic for the Split operator as defined in Opset 18+. When the num_outputs attribute is provided:

  1. It should compute block_size = ceil(dimension / num_outputs).
  2. The first num_outputs - 1 outputs should each have size block_size.
  3. The last output should receive the remainder, dimension - (num_outputs - 1) * block_size.

For an input length of 10 and num_outputs=3, the expected output shapes are [4, 4, 2].
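The sizing rule above can be sketched in a few lines of Python (the helper name is hypothetical, not part of the TVM or ONNX codebase):

```python
import math

def expected_split_sizes(dim, num_outputs):
    """Output sizes for ONNX Split (opset 18+) when only 'num_outputs' is given."""
    block_size = math.ceil(dim / num_outputs)
    sizes = [block_size] * (num_outputs - 1)
    # The last chunk takes whatever remains after the first N-1 full blocks.
    sizes.append(dim - block_size * (num_outputs - 1))
    return sizes

print(expected_split_sizes(10, 3))  # uneven case  -> [4, 4, 2]
print(expected_split_sizes(9, 3))   # uniform case -> [3, 3, 3]
```

This matches the ONNX Runtime shapes shown in the reproduction log below.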

Actual behavior

TVM correctly handles uniform splits (e.g., 9/3), but fails to convert the model when an uneven split is required (e.g., 10/3). The frontend throws a conversion error, indicating it cannot handle dimensions that are not perfectly divisible by the number of outputs.

Reproduction Log:

>>> Testing Split: Input Length 9 / 3 parts
  ONNX Runtime shapes: [3, 3, 3]
  TVM shapes:          [3, 3, 3]
  Result: PASS

>>> Testing Split: Input Length 10 / 3 parts
  ONNX Runtime shapes: [4, 4, 2]
  Error converting operator Split, with inputs: [X]
  Result: FAIL (Conversion or Runtime Error)
  Error: Traceback (most recent call last): ... src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!

Environment

  • OS: Ubuntu 20.04.6 LTS (Focal Fossa)
  • TVM Version: 0.19.0 (Relax)
  • ONNX Version: 1.18.0
  • ONNX Runtime Version: 1.24.1
  • NumPy Version: 2.4.2

Steps to reproduce

import onnx
from onnx import helper, TensorProto
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import onnxruntime as ort

def run_split_test(input_len, num_outputs):
    print(f"\n>>> Testing Split: Input Length {input_len} / {num_outputs} parts")
    
    # 1. Construct ONNX Model
    x_np = np.arange(input_len).astype(np.float32)
    node = helper.make_node(
        'Split',
        inputs=['X'],
        outputs=[f'Y{i}' for i in range(num_outputs)],
        axis=0,
        num_outputs=num_outputs
    )

    graph = helper.make_graph(
        [node],
        'split_test',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, [input_len])],
        [helper.make_tensor_value_info(f'Y{i}', TensorProto.FLOAT, [None]) for i in range(num_outputs)]
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

    # 2. Reference output (ORT)
    sess = ort.InferenceSession(model.SerializeToString())
    ort_outs = sess.run(None, {'X': x_np})
    ort_shapes = [o.shape[0] for o in ort_outs]
    print(f"  ONNX Runtime shapes: {ort_shapes}")

    # 3. TVM output
    try:
        tvm_mod = from_onnx(model)
        target = tvm.target.Target("llvm")
        exe = relax.build(tvm_mod, target)
        vm = relax.VirtualMachine(exe, tvm.cpu())
        
        tvm_outs = vm["main"](tvm.nd.array(x_np))
        tvm_shapes = [o.asnumpy().shape[0] for o in tvm_outs]
        print(f"  TVM shapes:          {tvm_shapes}")
        
        if tvm_shapes == ort_shapes:
            print("  Result: PASS")
        else:
            print("  Result: FAIL (Shape Mismatch)")
    except Exception as e:
        print("  Result: FAIL (Conversion or Runtime Error)")
        print(f"  Error: {str(e)[:100]}...")

if __name__ == "__main__":
    # Case 1: Uniform split (Should PASS)
    run_split_test(input_len=9, num_outputs=3)
    
    # Case 2: Non-uniform split (Should FAIL)
    run_split_test(input_len=10, num_outputs=3)
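A likely fix would be for the frontend to lower num_outputs into explicit split points, which relax.op.split accepts as a list of indices. A minimal sketch of that conversion, under the opset-18 rule described above (helper name hypothetical):

```python
import math

def num_outputs_to_indices(dim, num_outputs):
    """Cumulative split points implied by 'num_outputs' for ONNX Split (opset 18+).

    With block_size = ceil(dim / num_outputs), segment boundaries fall at
    block_size, 2 * block_size, ..., so the last segment absorbs the remainder.
    """
    block_size = math.ceil(dim / num_outputs)
    return [block_size * i for i in range(1, num_outputs)]

# num_outputs_to_indices(10, 3) -> [4, 8]: segments [0:4], [4:8], [8:10],
# i.e. sizes 4, 4, 2 -- matching ONNX Runtime in the log above.
```

This is only an illustration of the required index arithmetic, not a proposed patch; the actual fix would also need to handle symbolic dimensions where `dim` is not known at conversion time.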

Triage

  • relax:frontend:onnx
  • needs-triage

cc @KJlaccHoeUM9l
