-
Notifications
You must be signed in to change notification settings - Fork 193
Open
Labels
Description
Hi @i-riyad, could you please take a moment to look at this question? I would sincerely appreciate it.
The full code is available at https://github.com/PonyPinkPie/export in example.py.
1、Environment
- linux:Ubuntu 20.04.3 LTS
- CUDA:Cuda compilation tools, release 12.8, V12.8.61
- cudnn:ubuntu2004-9.4.0_1.0-1_amd64
- TensorRT:TensorRT-10.13.3.9
- CPU: Intel(R) Core(TM) i9-10900X CPU @ 3.70GHz
- GPU: RTX4090D
2、torch2onnx and onnx2trt code
My ONNX file is in FP32 mode, without any other processing such as PTQ or QAT. It seems that the flag config.set_flag(trt.BuilderFlag.FP8) does not take effect.
| engine mode | FPS |
|---|---|
| fp16 | 2819 |
| fp8 | 982 |
def torch2onnx(
        model=None,
        dummy_input=None,
        onnx_model_name=None,
        dynamic_shape=False,
        opset_version=17,
        do_constant_folding=True,
        verbose=False):
    """Export a PyTorch model to an ONNX file.

    The model and dummy input are moved to CUDA, a forward pass is run to
    discover the output structure, and ``torch.onnx.export`` is invoked with
    per-tensor dynamic axes.

    Args:
        model: the ``torch.nn.Module`` to export.
        dummy_input: example input tensor(s); a tuple is normalised to a list.
        onnx_model_name: destination path of the exported ONNX file.
        dynamic_shape: if True, mark every axis of every tensor dynamic;
            otherwise only axis 0 (the batch axis) is dynamic.
        opset_version: ONNX opset to target.
        do_constant_folding: forwarded to ``torch.onnx.export``.
        verbose: forwarded to ``torch.onnx.export``.
    """
    # Normalise tuple inputs to a list before handing to the helpers below.
    if isinstance(dummy_input, tuple):
        dummy_input = list(dummy_input)
    # NOTE(review): `to` is a project helper — presumably a recursive
    # device-move over nested containers; verify against its definition.
    dummy_input = to(dummy_input, 'cuda')
    model.eval().cuda()
    # One forward pass so we can derive output names / dynamic axes.
    with torch.no_grad():
        output = model(dummy_input)
    assert not isinstance(dummy_input, dict), 'input should not be dict.'
    assert not isinstance(output, dict), 'output should not be dict'
    input_names = get_names(dummy_input, 'input')
    output_names = get_names(output, 'output')
    # Pair every tensor with its name and record which axes are dynamic.
    all_names = input_names + output_names
    all_tensors = flatten(dummy_input) + flatten(output)
    dynamic_axes = {
        name: (list(range(tensor.dim())) if dynamic_shape else [0])
        for name, tensor in zip(all_names, all_tensors)
    }
    torch.onnx.export(
        model,
        dummy_input,
        onnx_model_name,
        input_names=input_names,
        output_names=output_names,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        verbose=verbose,
        dynamic_axes=dynamic_axes)
    # Free the CUDA memory held by the traced graph / dummy forward pass.
    torch.cuda.empty_cache()
def onnx2trt(
        model,
        log_level='ERROR',
        max_batch_size=1,
        min_input_shapes=None,
        max_input_shapes=None,
        max_workspace_size=1,
        fp16_mode=True,
        strict_type_constraints=False,
        int8_mode=False,
        int8_calibrator=None,
        fp8_mode=True,
        ):
    """Build a TensorRT engine from an ONNX model.

    Args:
        model: path to an ONNX file, or a readable binary file-like object.
        log_level: name of a ``trt.Logger`` severity (e.g. ``'ERROR'``).
        max_batch_size: upper bound of the batch dimension in the
            optimization profile (and ``builder.max_batch_size`` on TRT < 10).
        min_input_shapes / max_input_shapes: per-input shapes *excluding*
            the batch dimension; both or neither must be given. Defaults to
            the shapes recorded in the network.
        max_workspace_size: workspace size multiplier (TRT < 10 only).
        fp16_mode / int8_mode / fp8_mode: enable the corresponding builder
            precision flags.
        strict_type_constraints: set ``trt.BuilderFlag.STRICT_TYPES``.
        int8_calibrator: calibrator for INT8; a dummy entropy calibrator is
            created when omitted.

    Returns:
        On TRT < 10 an ``ICudaEngine``; on TRT >= 10 the serialized network
        (``IHostMemory``), or None if the build fails.

    Raises:
        RuntimeError: if the ONNX model cannot be parsed.
    """
    logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    if isinstance(model, str):
        with open(model, 'rb') as f:
            parsed = parser.parse(f.read())
    else:
        parsed = parser.parse(model.read())
    if not parsed:
        # Fix: the original only printed the errors and kept building from a
        # broken/empty network. Fail fast with the parser diagnostics instead.
        errors = '\n'.join(
            str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError('Failed to parse ONNX model:\n' + errors)
    # Re-order output tensors: unmark the originals and re-mark identity
    # copies so output ordering/naming is deterministic.
    output_tensors = [network.get_output(i)
                      for i in range(network.num_outputs)]
    for tensor in output_tensors:
        network.unmark_output(tensor)
    for tensor in output_tensors:
        identity_out_tensor = network.add_identity(tensor).get_output(0)
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)
    if int(trt_version[0]) < 10:
        # max_batch_size / max_workspace_size were removed in TRT 10.
        builder.max_batch_size = max_batch_size
    config = builder.create_builder_config()
    if int(trt_version[0]) < 10:
        # NOTE(review): 1 << 32 makes the unit 4 GiB per count, not 1 GiB —
        # confirm this is intended.
        config.max_workspace_size = max_workspace_size * (1 << 32)
    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)
    if int8_mode:
        config.set_flag(trt.BuilderFlag.INT8)
        if int8_calibrator is None:
            # Calibrate with all-ones dummy data when no calibrator is given.
            shapes = [(1,) + network.get_input(i).shape[1:]
                      for i in range(network.num_inputs)]
            dummy_data = gen_ones(shapes)
            int8_calibrator = EntropyCalibrator2(CustomDataset(dummy_data))
        config.int8_calibrator = int8_calibrator
    if fp8_mode:
        # NOTE(review): the FP8 builder flag alone does not quantize a plain
        # FP32 graph — TensorRT expects explicit Q/DQ (quantize/dequantize)
        # nodes in the ONNX model for FP8; confirm against the TRT docs.
        config.set_flag(trt.BuilderFlag.FP8)
    if strict_type_constraints:
        # NOTE(review): STRICT_TYPES is deprecated/removed in newer TRT
        # releases (PREFER_PRECISION_CONSTRAINTS et al. replace it); this
        # line may raise on TRT 10 — verify.
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    # Set the dynamic-shape optimization profile; min/max shape overrides
    # must be supplied together.
    assert not (bool(min_input_shapes) ^ bool(max_input_shapes))
    profile = builder.create_optimization_profile()
    input_shapes = [network.get_input(i).shape[1:]
                    for i in range(network.num_inputs)]
    if not min_input_shapes:
        min_input_shapes = input_shapes
    if not max_input_shapes:
        max_input_shapes = input_shapes
    assert len(min_input_shapes) == len(max_input_shapes) == len(input_shapes)
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        min_shape = (1,) + min_input_shapes[i]
        max_shape = (max_batch_size,) + max_input_shapes[i]
        # Opt shape is the element-wise midpoint of min and max.
        opt_shape = [(min_ + max_) // 2
                     for min_, max_ in zip(min_shape, max_shape)]
        profile.set_shape(tensor.name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)
    if int(trt_version[0]) < 10:
        engine = builder.build_engine(network, config)
    else:
        # TRT 10+: build_engine was removed; deserialize the returned bytes
        # with a Runtime before use.
        engine = builder.build_serialized_network(network, config)
    return engine
coderabbitai and Graham1025