
Significant output discrepancy between TensorRT engine and ONNX Runtime inference outputs #4415

@WoodieDudy

Description

When running inference with a TensorRT engine built from an ONNX model, I observe significant discrepancies between the TensorRT and ONNX Runtime outputs.
The difference is not minor: both the mean and the max deviation are large.

Observed output deviations:

Output 0: Mean deviation = 3.970133066177368, Max deviation = 16.034887313842773
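For context, since absolute deviations alone don't show scale, a relative check such as the following (my own addition, reusing trt_output and ort_output from the comparison loop in the script below) confirms the outputs genuinely diverge rather than differing by rounding noise:

# Hypothetical follow-up check, not part of the original script:
# relative deviation against the ONNX Runtime reference output.
rel = np.abs(trt_output - ort_output) / (np.abs(ort_output) + 1e-6)
print(f"Median relative deviation = {np.median(rel)}")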

Environment

TensorRT Version: 10.9.0.34

NVIDIA GPU: A100 40GB

NVIDIA Driver Version: 550.127.05

CUDA Version: 12.8

CUDNN Version: 9.8.0

Operating System: Ubuntu 24.04

Python Version: 3.12.3

PyTorch Version: 2.6.0+cu124

Container: nvcr.io/nvidia/tensorrt:25.03-py3

Steps To Reproduce

Commands or scripts:

  1. Export the ONNX model from PyTorch using the provided script:
    https://gist.github.com/WoodieDudy/f91209ff64d3d84e1fab7d8860f18d42
    or download the ONNX file directly:
    https://drive.google.com/file/d/1ItOgKQtcg47lqooq9G1Qi7pLz6qv1fYG/view?usp=sharing
  2. Build the TensorRT engine (see the precision note after the script):
    trtexec --onnx=model_static.onnx --saveEngine=model_static.engine
  3. Run the following Python script for inference and comparison:
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes the CUDA context on import
import onnxruntime

class TRTInference:
    def __init__(self, engine_path: str):
        # Deserialize the prebuilt engine.
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        self.engine = runtime.deserialize_cuda_engine(engine_data)
        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()
        # Partition the engine's IO tensors into inputs and outputs.
        self.input_tensor_indices = []
        self.output_tensor_indices = []
        for i in range(self.engine.num_io_tensors):
            tensor_name = self.engine.get_tensor_name(i)
            mode = self.engine.get_tensor_mode(tensor_name)
            if mode == trt.TensorIOMode.INPUT:
                self.input_tensor_indices.append(i)
            else:
                self.output_tensor_indices.append(i)

    def infer(self, input_tensors: list):
        num_tensors = self.engine.num_io_tensors
        bindings = [None] * num_tensors
        # Copy each input to the device and record its address.
        for idx, tensor_index in enumerate(self.input_tensor_indices):
            # pycuda requires a contiguous host buffer.
            input_array = np.ascontiguousarray(input_tensors[idx])
            tensor_name = self.engine.get_tensor_name(tensor_index)
            self.context.set_input_shape(tensor_name, input_array.shape)
            input_mem = cuda.mem_alloc(input_array.nbytes)
            cuda.memcpy_htod_async(input_mem, input_array, self.stream)
            bindings[tensor_index] = int(input_mem)
        # Allocate device buffers for the outputs, using the shapes the
        # context resolves once all input shapes are set.
        output_buffers = {}
        for tensor_index in self.output_tensor_indices:
            tensor_name = self.engine.get_tensor_name(tensor_index)
            out_shape = tuple(self.context.get_tensor_shape(tensor_name))
            dtype = trt.nptype(self.engine.get_tensor_dtype(tensor_name))
            nbytes = int(np.prod(out_shape)) * np.dtype(dtype).itemsize
            output_mem = cuda.mem_alloc(nbytes)
            bindings[tensor_index] = int(output_mem)
            output_buffers[tensor_index] = (output_mem, out_shape, dtype)
        # Bind every tensor address, then run the engine on the stream.
        for i in range(num_tensors):
            tensor_name = self.engine.get_tensor_name(i)
            self.context.set_tensor_address(tensor_name, bindings[i])
        self.context.execute_async_v3(stream_handle=self.stream.handle)
        self.stream.synchronize()
        # Copy the outputs back to the host.
        outputs = []
        for tensor_index in self.output_tensor_indices:
            output_mem, out_shape, dtype = output_buffers[tensor_index]
            host_output = np.empty(out_shape, dtype=dtype)
            cuda.memcpy_dtoh(host_output, output_mem)
            outputs.append(host_output)
        return outputs

batch_size = 16
np.random.seed(0)  # fixed seed so the comparison is reproducible
input_x_np = np.random.rand(batch_size, 240_000).astype(np.float32)
input_xlen_np = np.ones((batch_size,), dtype=np.float32)

# TensorRT inference.
engine_path = 'model_static.engine'
inference_engine = TRTInference(engine_path)
trt_outputs = inference_engine.infer([input_x_np, input_xlen_np])
print("TensorRT outputs:")
for idx, output in enumerate(trt_outputs):
    print(f"Output {idx} shape: {output.shape}")

# ONNX Runtime inference on the same inputs; SkipLayerNormFusion is
# disabled so the reference path does not fuse these ops itself.
ort_session = onnxruntime.InferenceSession(
    'model_static.onnx',
    providers=['CUDAExecutionProvider'],
    disabled_optimizers=["SkipLayerNormFusion"],
)
input_names = [inp.name for inp in ort_session.get_inputs()]
ort_inputs = {input_names[0]: input_x_np, input_names[1]: input_xlen_np}
ort_outputs = ort_session.run(None, ort_inputs)
print("ONNX Runtime outputs:")
for idx, output in enumerate(ort_outputs):
    print(f"Output {idx} shape: {output.shape}")

# Elementwise absolute deviation between the two backends.
for idx, (trt_output, ort_output) in enumerate(zip(trt_outputs, ort_outputs)):
    diff = np.abs(trt_output - ort_output)
    print(f"Output {idx}: Mean deviation = {np.mean(diff)}, Max deviation = {np.max(diff)}")

Have you tried the latest release?: Yes, using container 25.03.
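
Precision note: one variable worth ruling out on the A100 is that, with the default builder settings in step 2, TensorRT can use TF32 for FP32 matmuls and convolutions on Ampere, which by itself introduces drift against ONNX Runtime's FP32 path. A sanity check using the standard trtexec flag (this rebuild is my suggestion, not something already tried above) is:

trtexec --onnx=model_static.onnx --saveEngine=model_static.engine --noTF32

That said, deviations as large as ~16 are probably beyond what TF32 rounding alone would explain.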

Metadata

Labels: Investigating (Issue is under investigation by TensorRT devs), Module:Accuracy (Output mismatch between TensorRT and other frameworks), triaged (Issue has been triaged by maintainers)
