
Precision issues of TensorRT 10.11.0.33 when running model on GPU 4090 #4543

@CallmeZhangChenchen

Description

I encountered precision issues when running this model on the GPU.
The results from onnxruntime are consistent with those from polygraphy run --trt.
However, the inference results produced by my own Python TensorRT code are inconsistent with both of them.

Environment

TensorRT Version: 10.11.0.33

NVIDIA GPU: 4090

NVIDIA Driver Version: 550.163.01

CUDA Version: 12.6 (nvidia-cuda-runtime-cu12 12.6.77; tensorrt_cu12, tensorrt_cu12_libs)

CUDNN Version: nvidia-cudnn-cu12 9.5.1.17

Operating System:

Python Version (if applicable): Python 3.10.16

Tensorflow Version (if applicable): None

PyTorch Version (if applicable): torch 2.7.0

Baremetal or Container (if so, version): None

Relevant Files

Model link:
test.onnx

----------------- input -----------------
{'name': 'lm_input', 'shape': [1, 'seq_len', 1024], 'type': 'tensor(float)'}
{'name': 'att_mask', 'shape': [1, 'seq_len', 'seq_len'], 'type': 'tensor(bool)'}
----------------- output -----------------
{'name': 'xs',
 'shape': ['LayerNormalizationxs_dim_0', 'seq_len', 1024],
 'type': 'tensor(float)'}
{'name': 'r_att_cache',
 'shape': ['Concatr_att_cache_dim_0',
           'Concatr_att_cache_dim_1',
           'seq_len',
           'Concatr_att_cache_dim_3'],
 'type': 'tensor(float)'}
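
For reference, the listing above can be reproduced with onnxruntime's model metadata API (a small sketch; the formatting is mine):

from pprint import pprint
import onnxruntime as ort

sess = ort.InferenceSession("test.onnx")
print("----------------- input -----------------")
for t in sess.get_inputs():
    pprint({'name': t.name, 'shape': t.shape, 'type': t.type})
print("----------------- output -----------------")
for t in sess.get_outputs():
    pprint({'name': t.name, 'shape': t.shape, 'type': t.type})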

Steps To Reproduce

Commands or scripts:
att_mask.npy
lm_input.npy
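
As a quick sanity check of the two input files (shapes and dtypes only; note from the I/O listing above that the model expects lm_input as float32 and att_mask as bool):

import numpy as np

lm_input = np.load('lm_input.npy')
att_mask = np.load('att_mask.npy')
print(lm_input.shape, lm_input.dtype)
print(att_mask.shape, att_mask.dtype)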

onnxruntime result

import numpy as np
import onnxruntime as ort

# Load the saved test inputs (the model expects float32 lm_input and bool att_mask)
input1 = np.load('lm_input.npy').astype(np.float32)
input2 = np.load('att_mask.npy')

sess = ort.InferenceSession("test.onnx")
input_names = [inp.name for inp in sess.get_inputs()]
outputs = sess.run(None, {input_names[0]: input1, input_names[1]: input2})
print("Inference result:", outputs[0])
Inference result: [[[ 0.5224824   0.04726911 -0.09089257 ... -0.02409136 -0.01106405
   -0.07198416]
  [ 0.6034443   0.32908097  0.10189726 ...  0.24078083  0.33037812
    0.13956162]
  [ 0.56458145 -0.17910141 -0.15065426 ... -0.03973707  0.14174853
    0.55276424]
  ...
  [ 0.41438544  0.73719823 -0.8061858  ... -0.3429834   0.48010287
   -0.23982278]
  [ 0.4162167   0.6774794  -0.78476673 ... -0.34373513  0.47167927
   -0.20082474]
  [ 0.49580193 -0.6017233   0.60250455 ...  0.720528   -1.15546
   -1.007982  ]]]

tensorrt result

import tensorrt as trt


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    # load onnx model
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError('failed to parse {}'.format(onnx_model))
    # set input shapes
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
    # force every network input and output tensor to the requested data type
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # save trt engine
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    # logging.info("Succesfully convert onnx to trt...")

def get_first_trt_kwargs():
    min_shape = [(1, 1, 1024), (1,1,1)]
    opt_shape = [(1, 100, 1024), (1, 100, 100)]
    max_shape = [(1, 200, 1024), (1,200,200)]
    input_names = ['lm_input', 'att_mask']
    return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}


convert_onnx_to_trt('test.trt', get_first_trt_kwargs(),'test.onnx',False)

import torch
with open('test.trt', 'rb') as f:
    trt_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())

device = 'cuda'
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))

import numpy as np
lm_input = torch.from_numpy(np.load('lm_input.npy').astype(np.float32)).to(device)
att_mask = torch.from_numpy(np.load('att_mask.npy')).to(device)

att_cache_return = torch.empty((14, 16, lm_input.size(1), 128), device=lm_input.device, dtype=torch.float32)

torch.cuda.current_stream().synchronize()
with trt_stream:
    trt_context.set_input_shape('lm_input', (1, lm_input.size(1), 1024))
    trt_context.set_input_shape('att_mask', (1, lm_input.size(1), lm_input.size(1)))
    # Bind buffers by engine tensor index (assumed order: lm_input, att_mask, xs, r_att_cache).
    # Note that the 'xs' output is written back into lm_input's buffer, which is what gets printed below.
    data_ptrs = [lm_input.float().contiguous().data_ptr(), att_mask.contiguous().data_ptr(),
                 lm_input.data_ptr(), att_cache_return.data_ptr()]
    for i, ptr in enumerate(data_ptrs):
        trt_context.set_tensor_address(trt_engine.get_tensor_name(i), ptr)
    # run trt engine
    assert trt_context.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
    torch.cuda.current_stream().synchronize()

print(lm_input)
[08/05/2025-19:23:57] [TRT] [I] Loaded engine size: 774 MiB
[08/05/2025-19:23:57] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +77, now: CPU 0, GPU 848 (MiB)
tensor([[[ 0.5207,  0.0565, -0.0968,  ..., -0.0262, -0.0120, -0.0723],
         [ 0.5955, -0.4025, -1.5997,  ..., -0.3384, -0.5182, -0.6315],
         [ 0.5913,  0.1236, -0.1837,  ..., -0.9531,  0.0454, -0.7702],
         ...,
         [ 0.5218, -0.4751,  0.3485,  ...,  0.0597, -1.3055, -0.5576],
         [ 0.5116, -1.2225,  1.2767,  ..., -1.4294, -0.1401, -0.6012],
         [ 0.4806,  0.4716,  0.4966,  ..., -0.4568,  0.1386, -0.8197]]],
       device='cuda:0')

Have you tried the latest release?:

TensorRT Version: 10.11.0.33

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

model_inputs.json
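
(For completeness: model_inputs.json can be generated from the two .npy files with polygraphy's JSON helper. This is a sketch based on polygraphy's custom-input-data example, and I am assuming this is roughly how the file was produced; the expected format is a list of feed dicts, one per iteration.)

import numpy as np
from polygraphy.json import save_json

feed_dict = {
    "lm_input": np.load("lm_input.npy").astype(np.float32),
    "att_mask": np.load("att_mask.npy"),
}
save_json([feed_dict], "model_inputs.json", description="custom input data")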

polygraphy run test.onnx --trt --onnxrt --trt-outputs mark all --onnx-outputs mark all --atol 1e-2 --rtol 1e-3 --fail-fast --input-shapes lm_input:[1,88,1024] att_mask:[1,88,88] --load-inputs ./model_inputs.json

[I] Accuracy Summary | trt-runner-N0-08/05/25-19:26:54 vs. onnxrt-runner-N0-08/05/25-19:26:54 | Passed: 1/1 iterations | Pass Rate: 100.0%
[I] PASSED | Runtime: 69.595s | Command: /root/miniconda3/envs/cosyvoice2/bin/polygraphy run test.onnx --trt --onnxrt --trt-outputs mark all --onnx-outputs mark all --atol 1e-2 --rtol 1e-3 --fail-fast --input-shapes lm_input:[1,88,1024] att_mask:[1,88,88] --load-inputs ./model_inputs.json


polygraphy run test.trt --trt --input-shapes lm_input:[1,88,1024] att_mask:[1,88,88] --load-inputs ./model_inputs.json --save-outputs output.json --model-type engine

There is a very strange warning ([W]) in the output:

[I] RUNNING | Command: /root/miniconda3/envs/cosyvoice2/bin/polygraphy run test.trt --trt --input-shapes lm_input:[1,88,1024] att_mask:[1,88,88] --load-inputs model_inputs.json --save-outputs output.json --model-type engine
[I] Loading input data from model_inputs.json
[I] trt-runner-N0-08/05/25-19:29:00     | Activating and starting inference
[I] Loading bytes from /models/cosyvoice_models/panxb/CosyVoice-300M-25Hz-spk-online-all-adddy176-jiqing65-except-miaobo-pengfei-day0612-filter-finetune-miaobo-pengfei-shth/zcc_code/test.trt
[W] Input tensor: att_mask | Buffer dtype (bool) does not match expected input dtype (float32), attempting to cast. 
[I] trt-runner-N0-08/05/25-19:29:00    
    ---- Inference Input(s) ----
    {lm_input [dtype=float32, shape=(1, 88, 1024)],
     att_mask [dtype=float32, shape=(1, 88, 88)]}
[I] trt-runner-N0-08/05/25-19:29:00    
    ---- Inference Output(s) ----
    {xs [dtype=float32, shape=(1, 88, 1024)],
     r_att_cache [dtype=float32, shape=(14, 16, 88, 128)]}
[I] trt-runner-N0-08/05/25-19:29:00     | Completed 1 iteration(s) in 17.1 ms | Average inference time: 17.1 ms.
[I] Saving inference results to output.json
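
The warning above says the engine itself expects float32 for att_mask, which matches the explicit dtype override in the build script (input_tensor.dtype = tensor_dtype). The engine's I/O dtypes can be confirmed directly with something like:

import tensorrt as trt

with open('test.trt', 'rb') as f:
    engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())

# Print every I/O tensor together with the dtype the engine expects/produces.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name, engine.get_tensor_mode(name), engine.get_tensor_dtype(name))
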
import base64
import io
import json

import numpy as np

with open('output.json', 'r') as f:
    data = json.load(f)

# Each array in the polygraphy output file is stored as a base64-encoded .npy blob.
raw = base64.b64decode(data['lst'][0][1][0]['outputs']['xs']['values']['array'].encode(), validate=True)
xs_trt = np.load(io.BytesIO(raw), allow_pickle=False)
print(xs_trt)

The result is the same as the onnxruntime output:

array([[[ 0.5224831 ,  0.04727471, -0.09088739, ..., -0.02410248,
         -0.01106691, -0.07198079],
        [ 0.60342693,  0.32877445,  0.10219606, ...,  0.2401557 ,
          0.33079317,  0.1393766 ],
        [ 0.5645711 , -0.17879786, -0.15065011, ..., -0.03974693,
          0.14182723,  0.55259246],
        ...,
        [ 0.4144164 ,  0.7374819 , -0.8063961 , ..., -0.3434482 ,
          0.48050952, -0.24006712],
        [ 0.4162233 ,  0.67745477, -0.78478736, ..., -0.34435624,
          0.47210816, -0.20117404],
        [ 0.49581778, -0.6018964 ,  0.6024814 , ...,  0.7206898 ,
         -1.1551262 , -1.0074697 ]]], dtype=float32)
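
To put a number on the agreement, the decoded polygraphy output can be compared with the onnxruntime result directly (a quick check; ort_out here stands for the outputs[0] array from the onnxruntime snippet above and is not defined in the scripts):

import numpy as np

# ort_out: outputs[0] from the onnxruntime run; xs_trt: the array decoded from output.json
print(np.allclose(ort_out, xs_trt, atol=1e-2, rtol=1e-3))
print(np.abs(ort_out - xs_trt).max())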
