TensorRT fails to infer the output shape of a valid ONNX model #4471

@coffezhou

Description

For the following valid ONNX model (graph image omitted), TensorRT infers the wrong output shape. TensorRT reports the shape of final_output as:

(1, 0, 32, 7)

However, when I run the same model with ONNX Runtime, the shape of final_output is:

(1, 1, 32, 7)

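The wrong shape then causes GPU memory allocation to fail. Because one dimension is 0, the output buffer is presumably 0 bytes, and cuMemAlloc returns CUDA_ERROR_INVALID_VALUE ("invalid argument") for a 0-byte request: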

 device_mem = cuda.mem_alloc(host_mem.nbytes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pycuda._driver.LogicError: cuMemAlloc failed: invalid argument
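The arithmetic behind the failure: (1, 1, 32, 7) holds 1 × 1 × 32 × 7 = 224 elements (896 bytes as float32), while (1, 0, 32, 7) holds 0 elements. A minimal sketch of a guard that catches this before calling cuda.mem_alloc follows. It assumes float32 outputs, and output_nbytes is a hypothetical helper, not part of TensorRT or PyCUDA.

```python
import math


def output_nbytes(shape, itemsize=4):
    """Byte size of a dense tensor; itemsize=4 assumes float32 elements."""
    # A negative dimension (e.g. -1 for an unresolved dynamic dim) cannot be allocated.
    if any(d < 0 for d in shape):
        raise ValueError(f"shape {tuple(shape)} has unresolved dimensions")
    nbytes = math.prod(shape) * itemsize
    if nbytes == 0:
        # cuMemAlloc rejects 0-byte requests with "invalid argument".
        raise ValueError(
            f"shape {tuple(shape)} has zero elements; refusing a 0-byte allocation"
        )
    return nbytes


print(output_nbytes((1, 1, 32, 7)))  # 896
```

In the failing script, the shape would come from engine.get_tensor_shape(name), and the returned byte count would then be passed to cuda.mem_alloc. With the shape TensorRT reports here, the guard raises a clear error instead of the opaque LogicError.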

Environment

TensorRT Version: 10.11.0.33

NVIDIA GPU: GeForce RTX 3080

NVIDIA Driver Version: 535.183.01

CUDA Version: 12.2

cuDNN Version: none

Operating System: Ubuntu 20.04

Python Version (if applicable): 3.12.9

Tensorflow Version (if applicable): none

PyTorch Version (if applicable): none

Baremetal or Container (if so, version): none

Steps To Reproduce

The bug can be reproduced with the following script and the model from the attachment (testcast.zip). As the script shows, ONNX Runtime executes the model successfully.

import sys
import pickle

import numpy as np  # noqa: F401  (the pickled inputs are NumPy arrays)
import onnx
import onnxruntime

import tensorrt as trt
import pycuda.driver as cuda  # noqa: F401
import pycuda.autoinit  # noqa: F401  (creates a CUDA context on import)


def test():
    onnx_model = onnx.load("1111.onnx")

    # Dict mapping input names to NumPy arrays, shipped with the attachment.
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Reference run on ONNX Runtime (CPU).
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run(None, inputs)  # None -> return all outputs
    except Exception as e:
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)

    print("ONNXRuntime:\n", ort_output[0].shape)

    # --------------------------------------------------------
    # Build a TensorRT engine from the same ONNX file.
    trt_logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(trt_logger, "")
    builder = trt.Builder(trt_logger)
    # TensorRT 10 networks are always explicit-batch; the EXPLICIT_BATCH
    # flag is deprecated, so no creation flags are needed.
    network = builder.create_network(0)

    parser = trt.OnnxParser(network, trt_logger)
    with open("1111.onnx", "rb") as model_file:
        if not parser.parse(model_file.read()):
            for idx in range(parser.num_errors):
                print(parser.get_error(idx))
            sys.exit(1)

    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("Engine build failed!")
        sys.exit(1)

    with open("engine.trt", "wb") as f:
        f.write(serialized_engine)

    # Keep the runtime alive for as long as the engine is in use.
    runtime = trt.Runtime(trt_logger)
    with open("engine.trt", "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())

    # ------------------------------------------------------------
    # Print every I/O tensor with the TensorRT 10 name-based API.
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        print(name, engine.get_tensor_mode(name), engine.get_tensor_shape(name))


if __name__ == "__main__":
    test()

testcast.zip
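The script compares the two shapes by eye. To make the comparison an explicit check, one could collect each runtime's output shapes into a dictionary and diff them. The following is a minimal sketch: shape_mismatches is a hypothetical helper, and the dictionaries below simply hard-code the shapes reported above. In the script, they could be built as {o.name: a.shape for o, a in zip(ort_session.get_outputs(), ort_output)} and {engine.get_tensor_name(i): tuple(engine.get_tensor_shape(engine.get_tensor_name(i))) for i in range(engine.num_io_tensors)}. The second expression assumes trt.Dims converts cleanly with tuple(), and it yields -1 for dynamic dimensions.

```python
def shape_mismatches(reference, candidate):
    """Return {name: (reference_shape, candidate_shape)} for every tensor whose
    shape in `candidate` is missing (None) or differs from `reference`."""
    mismatches = {}
    for name, ref_shape in reference.items():
        cand_shape = candidate.get(name)
        cand = None if cand_shape is None else tuple(cand_shape)
        if cand != tuple(ref_shape):
            mismatches[name] = (tuple(ref_shape), cand)
    return mismatches


# Shapes reported in this issue for the model's output.
ort_shapes = {"final_output": (1, 1, 32, 7)}
trt_shapes = {"final_output": (1, 0, 32, 7)}
print(shape_mismatches(ort_shapes, trt_shapes))
# {'final_output': ((1, 1, 32, 7), (1, 0, 32, 7))}
```

ONNX Runtime is used as the reference because it executes the model correctly. Any tensor that TensorRT drops or reshapes shows up in the result.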

Commands or scripts:

Have you tried the latest release?: yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): yes, the model can be executed by ONNX Runtime.

Labels: Module:ONNX (issues relating to ONNX usage and import), triaged (issue has been triaged by maintainers)