
Accuracy issue in TensorRT 10.12.0 when running an ONNX model on A10 and L20 GPUs (not tested on other GPUs) #4533

@huakaigo

Description

I’m using onnxruntime-tensorrt (a modified build) to run inference for the GDCN model (or a slightly modified version of it) on A10 and L20 GPUs. I encounter an accuracy issue when dynamic shapes are enabled, and it only occurs when the batch size is 1.

I have reproduced the differences between TensorRT (tensorrt-cu11 10.12.0.36) and ONNX Runtime 1.19.2 on a cropped model, as shown in the script below.

Additional details:

  • When I disable offloading of the Transpose operation to TensorRT (using a modified ONNX Runtime build), the accuracy returns to normal.
  • I set the batch-size dimension in the optimization profile using the format [min, opt, max] (see the sketch below):
    • [1, 1, 1]: results are normal
    • [1, 200, 200]: results are abnormal
(screenshot of the observed output differences attached in the original issue)
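
A minimal sketch of the two optimization-profile configurations with the TensorRT Python API (the input name "dense_input" and the trailing dimension 16 are placeholders for the real model inputs; the repro script below sets the profile for every model input automatically):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
profile = builder.create_optimization_profile()

# Case [1, 1, 1]: fixed batch size, results match ONNX Runtime
# profile.set_shape("dense_input", (1, 16), (1, 16), (1, 16))

# Case [1, 200, 200]: dynamic batch size, results diverge when running with batch size 1
profile.set_shape("dense_input", (1, 16), (200, 16), (200, 16))
config.add_optimization_profile(profile)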

Environment

TensorRT Version: 10.12.0.36 and 10.7.0.23

NVIDIA GPU: A10 and L20

NVIDIA Driver Version: 545.23.08

CUDA Version: 11.8

CUDNN Version: 8.0

Onnxruntime Version: 1.19.2

Operating System: CentOS 7.2

Python Version (if applicable): 3.8

Tensorflow Version (if applicable):

PyTorch Version (if applicable):

Baremetal or Container (if so, version):

Relevant Files

Model link:

trt_transpose_bug.zip

Steps To Reproduce

Run the Python script below to observe the differences between tensorrt-cu11-10.12.0.36 and onnxruntime-1.19.2.

python xxx.py

Commands or scripts:

import tensorrt as trt
import numpy as np
from typing import List, Iterable, NamedTuple, Sequence, Set, Dict, Union
import onnx
from onnx import TensorProto, ModelProto
from dataclasses import dataclass
from collections import OrderedDict
from copy import deepcopy
import sys

import pycuda.driver as cuda
import pycuda.autoinit
import torch
import os

# -------- helper functions begin --------

ONNX_DTYPE_TO_NP_TYPE = {
    1: np.float32,
    6: np.int32,
    7: np.int64,
    8: np.string_,
    10: np.float16,
    TensorProto.STRING: np.string_,
    TensorProto.DOUBLE: np.float64,
}

ONNX_DTYPE_TO_ONNX_TYPE = {
    np.longlong: TensorProto.INT64,
    np.ulonglong: TensorProto.UINT64,
    np.float64: TensorProto.DOUBLE,
    np.float_: TensorProto.FLOAT,
    np.float32: TensorProto.FLOAT,
    np.float16: TensorProto.FLOAT16,
    np.int64: TensorProto.INT64,
    np.int32: TensorProto.INT32,
    np.int16: TensorProto.INT16,
    np.int8: TensorProto.INT8,
    np.uint64: TensorProto.UINT64,
    np.uint32: TensorProto.UINT32,
    np.uint16: TensorProto.UINT16,
    np.uint8: TensorProto.UINT8,
    np.string_: TensorProto.STRING,
    np.bool_: TensorProto.BOOL,
}

@dataclass
class TensorType:
    shape: List[int]
    dtype: np.dtype

def get_shape_dtype_from_model(onnx_model: ModelProto,
                               bsz: int,
                               model_input_for_comm: Set[str] = set(),
                               comm_bsz: int = 1) -> Dict[str, TensorType]:
    input_shapes_dtype = {}
    for onnx_input in list(onnx_model.graph.input):
        name = onnx_input.name
        input_shape = []
        dim_param_num = 0
        for dim in onnx_input.type.tensor_type.shape.dim:
            if dim.dim_param != "":
                dim_param_num = dim_param_num + 1
                if name in model_input_for_comm:
                    input_shape.append(comm_bsz)
                else:
                    input_shape.append(bsz)
            else:
                input_shape.append(dim.dim_value)
            assert dim_param_num <= 1, f'Input {name} get dynamic dims {dim_param_num}, expected <=1'
        assert onnx_input.type.tensor_type.elem_type in ONNX_DTYPE_TO_NP_TYPE, f'unknown elem_type: [{onnx_input.type.tensor_type.elem_type}]'
        input_shapes_dtype[name] = TensorType(input_shape, ONNX_DTYPE_TO_NP_TYPE[onnx_input.type.tensor_type.elem_type])
    return input_shapes_dtype

def gen_rand_input_for_onnx_model(onnx_model: ModelProto,
                                  bsz: int, comm_bsz: int = 1) -> Dict[str, np.ndarray]:
    """Generate random input for the onnx model

    Args:
        onnx_model (ModelProto): Onnx model
        bsz (int): batch size

    Returns:
        Dict[str, np.ndarray]: random input_name:tensor dict
    """
    # Generate random input data for each model input.
    def get_comm_tower_inputs():
        comm_embedding_op_name = 'MergeEmbeddingLookupCombineOp__2579'
        model_input_for_comm = [input.name for input in onnx_model.graph.input if comm_embedding_op_name in input.name]
        return model_input_for_comm

    model_input_for_comm = get_comm_tower_inputs()

    input_data = {}
    input_shape_dtypes = get_shape_dtype_from_model(onnx_model, bsz, model_input_for_comm, comm_bsz)
    for onnx_input in onnx_model.graph.input:
        name = onnx_input.name
        dtype = input_shape_dtypes[name].dtype
        if dtype == np.float32:
            input_data[name] = np.random.uniform(-20, 20.0, input_shape_dtypes[name].shape).astype(dtype)
        else:
            input_data[name] = np.random.randint(0, 3, input_shape_dtypes[name].shape).astype(dtype)
    input_data = dict(sorted(input_data.items()))
    return input_data

def np_error_metrics(hint: str, array1: np.ndarray, array2: np.ndarray):
    """ compare and stat numpy.ndarray
    """
    if array1.dtype.type in [np.string_, np.object_]:
        print(f'{hint} with string dtype, skip comparing.')
        return
    absolute_error = np.abs(array1 - array2)
    # Avoid division by zero: where the reference value is 0, report a relative error of 0.
    reference = deepcopy(array1)
    zero_mask = (reference == 0)
    reference[zero_mask] = 1
    relative_error = absolute_error / np.abs(reference)
    relative_error[zero_mask] = 0

    max_error = np.max(absolute_error)

    p90_error = np.percentile(absolute_error, 90)
    p95_error = np.percentile(absolute_error, 95)
    p99_error = np.percentile(absolute_error, 99)

    max_relative_error = np.max(relative_error)

    p90_relative_error = np.percentile(relative_error, 90)
    p95_relative_error = np.percentile(relative_error, 95)
    p99_relative_error = np.percentile(relative_error, 99)

    print(f'{hint}')
    print(f'  max absolute error: {max_error}')
    print(f'  p99 absolute error: {p99_error}')
    print(f'  p95 absolute error: {p95_error}')
    print(f'  p90 absolute error: {p90_error}')
    print(f'---------------------------------')
    print(f'  max relative error: {max_relative_error}')
    print(f'  p99 relative error: {p99_relative_error}')
    print(f'  p95 relative error: {p95_relative_error}')
    print(f'  p90 relative error: {p90_relative_error}')

    return {
        'max_error': max_error,
        'p90_error': p90_error,
        'p95_error': p95_error,
        'p99_error': p99_error,
        'max_relative_error': max_relative_error,
        'p90_relative_error': p90_relative_error,
        'p95_relative_error': p95_relative_error,
        'p99_relative_error': p99_relative_error
    }
# -------- helper functions end --------

@dataclass
class BatchSizeRange:
    min: int
    max: int

class TrtInferWrapper():
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    class ProfileShape(NamedTuple):
        min_shape: Sequence[int]
        opt_shape: Sequence[int]
        max_shape: Sequence[int]

    def __init__(
        self,
        onnx_model_path:str,
        bsz_range: BatchSizeRange
    ):
        assert bsz_range.max >= bsz_range.min, f'bsz_range check failed, expect max >= min'
        self._onnx_model_path = onnx_model_path
        self._bsz_range = bsz_range
        self._profile_settings, self.output_names = self._get_meta_from_onnx_model()
        self._engine = self._build_engine()
        self._context = self._create_context()

    def _get_meta_from_onnx_model(self):
        onnx_model = onnx.load(self._onnx_model_path)
        output_names = [output.name for output in onnx_model.graph.output]
        min_shape_collect = get_shape_dtype_from_model(onnx_model, self._bsz_range.min)
        max_shape_collect = get_shape_dtype_from_model(onnx_model, self._bsz_range.max)
        profile_settings = {}
        for key in min_shape_collect.keys():
            profile_settings[key] = TrtInferWrapper.ProfileShape(min_shape_collect[key], max_shape_collect[key], max_shape_collect[key])
        return profile_settings, output_names

    def _build_engine(self):
        with trt.Builder(self.TRT_LOGGER) as builder, builder.create_network(self.EXPLICIT_BATCH) as network, trt.OnnxParser(network, self.TRT_LOGGER) as parser:
            config = builder.create_builder_config()
            config.set_memory_pool_limit(
                trt.MemoryPoolType.WORKSPACE, 1 << 30
            )

            with open(self._onnx_model_path, 'rb') as model:
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse ONNX file")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            profile = builder.create_optimization_profile()
            for tensor_name, tensor_profile in self._profile_settings.items():
                # print(f"Profile: {tensor_name}: {[ts.shape for ts in tensor_profile]}")
                profile.set_shape(tensor_name, *[ts.shape for ts in tensor_profile])
            config.add_optimization_profile(profile)
            engine = builder.build_engine_with_config(network, config)
            return engine

    def _create_context(self):
        return self._engine.create_execution_context()

    def run(self, input_datas: Dict[str, np.ndarray]):
        for input_name, data in input_datas.items():
            self._context.set_input_shape(input_name, data.shape)

        input_device_buffers = {}
        input_host_buffers = {}
        output_device_buffers = {}
        output_host_buffers = {}

        for binding in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(binding)
            mode = self._engine.get_tensor_mode(name)
            size = trt.volume(self._context.get_tensor_shape(name))
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            if mode == trt.TensorIOMode.INPUT:
                input_host_buffers[name] = host_mem
                input_device_buffers[name] = device_mem
            else: # trt.TensorIOMode.OUTPUT
                output_host_buffers[name] = host_mem
                output_device_buffers[name] = device_mem
            self._context.set_tensor_address(name, int(device_mem))



        # copy data to pinned memory to enable cudaMemcpyAsync
        list(map(lambda key: np.copyto(input_host_buffers[key], input_datas[key].ravel()), list(input_host_buffers.keys())))
        stream = cuda.Stream()
        # handle input
        list(map(lambda key: cuda.memcpy_htod_async(input_device_buffers[key], input_host_buffers[key], stream), list(input_host_buffers.keys())))
        # inference
        self._context.execute_async_v3(stream.handle)
        # handle output
        list(map(lambda key: cuda.memcpy_dtoh_async(output_host_buffers[key], output_device_buffers[key], stream), list(output_host_buffers.keys())))
        stream.synchronize()
        # reshape output tensors
        for key in output_host_buffers.keys():
            output_host_buffers[key] = output_host_buffers[key].reshape(self._context.get_tensor_shape(key))

        return output_host_buffers

class OrtModelInfer:
    def __init__(self, 
                 model_path:str,
                 EP: Union[str, List[str]],
                 device_id: int = 0,
                 custom_lib: Union[str, List[str], None] = None,
                 init_log_level:int = 2,
                 run_log_level: int = 2):
        import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
        import onnxruntime as rt  # type: ignore
        if isinstance(EP, str):
            EP = [EP]
        for e in EP:
            assert e in set(['CUDAExecutionProvider', 'CPUExecutionProvider', 'TensorrtExecutionProvider'])
        providers = EP
        model = onnx.load(model_path, load_external_data=False)
        self._input_names = [i.name for i in model.graph.input]
        self._output_names = [o.name for o in model.graph.output]
        model = None

        sess_options = rt.SessionOptions()
        if custom_lib is not None and isinstance(custom_lib, str):
            custom_lib = [custom_lib]
        if custom_lib is not None:
            for cl in custom_lib:
                if os.path.exists(cl):
                    sess_options.register_custom_ops_library(cl)
                else:
                    print("No such file '{}'".format(cl), file=sys.stderr)
                    exit(1)

        sess_options.graph_optimization_level = rt.GraphOptimizationLevel(3)
        sess_options.log_severity_level = init_log_level
        self._sess = rt.InferenceSession(model_path,
                                    sess_options=sess_options, providers=providers)
        self._sess.disable_fallback()
        self._run_options = rt.RunOptions()
        self._run_options.log_severity_level = run_log_level

    def run(self, ort_inputs):
        return OrderedDict(zip(self._output_names, self._sess.run(self._output_names, ort_inputs, run_options=self._run_options)))



def test_with_wide_range_shape():
    batch_size_range = BatchSizeRange(1, 200)
    trt_model = TrtInferWrapper(
        onnx_model_path=ONNX_PATH,
        bsz_range=batch_size_range,
    )
    ort_model = OrtModelInfer(
        model_path=ONNX_PATH,
        EP=['CPUExecutionProvider'],
    )
    onnx_model = onnx.load(ONNX_PATH)
    test_batch_size = [200, 1]
    for bsz in test_batch_size:
        data = gen_rand_input_for_onnx_model(onnx_model, bsz, bsz)
        trt_res = trt_model.run(data)
        ort_res = ort_model.run(data)
        print(f"=========bsz:{bsz}=========")
        for k in ort_res.keys():
            np_error_metrics(f"ORT VS TRT of [{k}]", ort_res[k], trt_res[k])

def test_batch_size_1():
    batch_size_range = BatchSizeRange(1, 1)
    trt_model = TrtInferWrapper(
        onnx_model_path=ONNX_PATH,
        bsz_range=batch_size_range,
    )
    ort_model = OrtModelInfer(
        model_path=ONNX_PATH,
        EP=['CPUExecutionProvider'],
    )
    onnx_model = onnx.load(ONNX_PATH)
    test_batch_size = [1]
    for bsz in test_batch_size:
        data = gen_rand_input_for_onnx_model(onnx_model, bsz, bsz)
        trt_res = trt_model.run(data)
        ort_res = ort_model.run(data)
        print(f"=========bsz:{bsz}=========")
        for k in ort_res.keys():
            np_error_metrics(f"ORT VS TRT of [{k}]", ort_res[k], trt_res[k])

def test_with_wide_range_shape_comm_bsz_1():
    batch_size_range = BatchSizeRange(1, 200)
    trt_model = TrtInferWrapper(
        onnx_model_path=ONNX_PATH,
        bsz_range=batch_size_range,
    )
    ort_model = OrtModelInfer(
        model_path=ONNX_PATH,
        EP=['CPUExecutionProvider'],
    )
    onnx_model = onnx.load(ONNX_PATH)
    test_batch_size = [200, 1]
    for bsz in test_batch_size:
        data = gen_rand_input_for_onnx_model(onnx_model, bsz, 1)
        trt_res = trt_model.run(data)
        ort_res = ort_model.run(data)
        print(f"=========bsz:{bsz}=========")
        for k in ort_res.keys():
            np_error_metrics(f"ORT VS TRT of [{k}]", ort_res[k], trt_res[k])

if __name__ == "__main__":
    ONNX_PATH = "trt_transpose_bug.onnx"
    print("==== test_batch_size_1 ====")
    test_batch_size_1()
    print("\n==== test_with_wide_range_shape ====\n")
    test_with_wide_range_shape()
    print("\n==== test_with_wide_range_shape_comm_bsz_1 ====\n")
    test_with_wide_range_shape_comm_bsz_1()
    
    

Have you tried the latest release?: yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): yes, onnxruntime-1.19.2 on CPUExecutionProvider
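
For reference, a similar TensorRT-vs-ONNX-Runtime comparison can also be attempted with Polygraphy. This is only a sketch; "dense_input" and the dimension 16 are placeholders for the actual input names/shapes of trt_transpose_bug.onnx, and every dynamic input would need its own shape entry:

polygraphy run trt_transpose_bug.onnx --trt --onnxrt \
    --trt-min-shapes dense_input:[1,16] \
    --trt-opt-shapes dense_input:[200,16] \
    --trt-max-shapes dense_input:[200,16] \
    --input-shapes dense_input:[1,16]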

Labels

Module:Accuracy (Output mismatch between TensorRT and other frameworks); Module:ONNX (Issues relating to ONNX usage and import)
