-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Description
For the following simple ONNX model,
the output Y produced by ONNX Runtime is as follows:
ONNXRuntime:
[[[[ True True True True]
[ True True True True]
[ True True True True]
[ True True True True]]
[[ True True True True]
[ True True True True]
[ True True True True]
[ True True True True]]
[[ True True True True]
[ True True True True]
[ True True True True]
[ True True True True]]]]
However, when I run it with TensorRT, the output Y is as follows:
TensorRT:
[[[[ True True True True]
[False False False False]
[False False False False]
[False False False False]]
[[ True True True True]
[False False False False]
[ True False False False]
[False False False False]]
[[ True True True True]
[False False False False]
[ True True True True]
[False False False False]]]]
64.6% of the elements are mismatched.
I have verified that the CAST_Y outputs of ONNX Runtime and TensorRT are identical.
This issue is similar to #4511; both models contain only the Or and Cast operators.
Environment
TensorRT Version: 10.12.0.36
NVIDIA GPU: GeForce RTX 3080
NVIDIA Driver Version: 535.183.01
CUDA Version: 12.2
CUDNN Version: none
Operating System: ubuntu 20.04
Python Version (if applicable): 3.12.9
Relevant Files
Model link:
Steps To Reproduce
This issue can be reproduced with the following code and the attached model.
from typing import Dict, List, Literal, Optional
import sys
import os
import numpy as np
import onnx
import onnxruntime
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import argparse
import pickle
def _run_onnxruntime(onnx_model, inputs):
    """Run *onnx_model* on onnxruntime's CPU execution provider.

    Returns ``(output_names, outputs)``; both lists follow the session's
    declared output order. Exits the process with status 1 if onnxruntime
    cannot create a session for, or run, the model.
    """
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:  # top-level repro boundary: report and stop
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)
    output_names = [o.name for o in ort_session.get_outputs()]
    return output_names, ort_output


def _build_engine(trt_logger, model_path, engine_path):
    """Parse *model_path* with TensorRT's ONNX parser and write a serialized engine to *engine_path*.

    Exits the process with status 1 if parsing or building fails.
    """
    trt.init_libnvinfer_plugins(trt_logger, '')
    builder = trt.Builder(trt_logger)
    network = builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)
    with open(model_path, 'rb') as model_file:
        if not parser.parse(model_file.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            sys.exit(1)
    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        sys.exit(1)
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)


def _run_tensorrt(trt_logger, engine_path, inputs):
    """Deserialize the engine at *engine_path*, run it once on *inputs*, and return ``{tensor_name: ndarray}``.

    *inputs* is indexed by the engine's input tensor names. Shapes are taken
    from the engine as-is, so this assumes static shapes (no -1 dims) — TODO
    confirm if the model is ever made dynamic.
    """
    # Keep the runtime referenced for as long as the engine is in use.
    runtime = trt.Runtime(trt_logger)
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        sys.exit(1)
    context = engine.create_execution_context()
    stream = cuda.Stream()

    # Hold every host/device buffer until the stream is synchronized; pycuda
    # frees device memory (and page-locked host memory) when the Python
    # object is garbage-collected, which would race with the async work.
    buffers = []
    outputs = {}  # name -> (host buffer, device buffer, engine shape)
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        shape = engine.get_tensor_shape(name)
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        host_mem = cuda.pagelocked_empty(trt.volume(shape), dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        buffers.append((host_mem, device_mem))
        context.set_tensor_address(name, int(device_mem))
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            np.copyto(host_mem, np.ravel(inputs[name]))
            cuda.memcpy_htod_async(device_mem, host_mem, stream)
        else:
            outputs[name] = (host_mem, device_mem, shape)

    if not context.execute_async_v3(stream_handle=stream.handle):
        print("TensorRT execution failed!")
        sys.exit(1)
    for host_mem, device_mem, _ in outputs.values():
        cuda.memcpy_dtoh_async(host_mem, device_mem, stream)
    # Host buffers are only valid after the device-to-host copies complete.
    stream.synchronize()
    return {
        name: host_mem.reshape(tuple(shape))
        for name, (host_mem, _, shape) in outputs.items()
    }


def test():
    """Run 333.onnx under onnxruntime and TensorRT on inputs.pkl and compare outputs.

    Prints the second output of each backend and asserts that all outputs
    agree (rtol=0.1, atol=0.1). The TensorRT engine is written to engine.trt.
    """
    onnx_model = onnx.load('333.onnx')
    with open("inputs.pkl", "rb") as fp:
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        inputs = pickle.load(fp)

    output_names, ort_output = _run_onnxruntime(onnx_model, inputs)
    print("ONNXRuntime:\n", ort_output[1])

    trt_logger = trt.Logger(trt.Logger.WARNING)
    _build_engine(trt_logger, '333.onnx', "engine.trt")
    trt_by_name = _run_tensorrt(trt_logger, "engine.trt", inputs)

    assert len(ort_output) == len(trt_by_name), "Unequal number of outputs"
    # Engine I/O order need not match the ONNX graph's output order, so pair
    # TensorRT outputs with onnxruntime's by tensor name, not by position.
    trt_output = [trt_by_name[name] for name in output_names]
    print("TensorRT: \n", trt_output[1])

    np.testing.assert_allclose(trt_output[0], ort_output[0], rtol=0.1, atol=0.1)  # CAST_Y, GOOD
    np.testing.assert_allclose(trt_output[1], ort_output[1], rtol=0.1, atol=0.1)  # Y, BAD
if __name__ == "__main__":
    # Script entry point: run the onnxruntime-vs-TensorRT comparison.
    test()
Commands or scripts:
Have you tried the latest release?: yes
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): yes, the model can be executed by ONNX Runtime.