@@ -293,29 +293,30 @@ def run_tensorrt(
     input_tensors,
     params,
     precision,
-    is_trt_engine=False,
     batch_size=1,
 ):
-    engine = None
-
-    # If the model file is a TensorRT engine then directly deserialize and run inference
-    # else convert the torch module to a TensorRT engine first and then run inference
-    if not is_trt_engine:
-        compile_settings = {
-            "inputs": input_tensors,
-            "enabled_precisions": {precision_to_dtype(precision)},
-            "truncate_long_and_double": params.get("truncate", False),
-        }
-
-        print("Converting method to TensorRT engine...")
-        with torch.no_grad(), torchtrt.logging.errors():
-            model = torchtrt.ts.convert_method_to_trt_engine(
-                model, "forward", **compile_settings
-            )
-
+    # Export an ONNX model and convert to TRT
+    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+    logger = trt.Logger(trt.Logger.WARNING)
+    builder = trt.Builder(logger)
+    network = builder.create_network(
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    )
+    parser = trt.OnnxParser(network, logger)
+    success = parser.parse_from_file("./tmp.onnx")
+    if not success:
+        raise ValueError("ONNX conversion failed")
+
+    config = builder.create_builder_config()
+    if precision == "fp16":
+        config.set_flag(trt.BuilderFlag.FP16)
+    start_compile = time.time_ns()
+    serialized_engine = builder.build_serialized_network(network, config)
+    end_compile = time.time_ns()
+    compile_time_s = (end_compile - start_compile) / 1e9
     # Deserialize the TensorRT engine
-    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-        engine = runtime.deserialize_cuda_engine(model)
+    with trt.Runtime(logger) as runtime:
+        engine = runtime.deserialize_cuda_engine(serialized_engine)

     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)
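The inference loop between this hunk and the next is unchanged and therefore elided by the diff. For context, a minimal sketch of driving a deserialized engine from torch-allocated buffers, assuming a single-input, single-output engine and the TensorRT 8.x binding API; `engine` and `input_tensor` stand in for the values produced above and are not part of this change:

    import torch

    # Hypothetical continuation: `engine` is the deserialized ICudaEngine
    # from above, `input_tensor` a CUDA tensor matching binding 0.
    context = engine.create_execution_context()

    # Allocate an output buffer shaped like binding 1 (TRT 8.x API).
    out_shape = tuple(context.get_binding_shape(1))
    output = torch.empty(out_shape, dtype=torch.float32, device="cuda")

    # execute_v2 takes one raw device pointer per binding.
    bindings = [int(input_tensor.data_ptr()), int(output.data_ptr())]
    context.execute_v2(bindings)
    torch.cuda.synchronize()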
@@ -350,7 +351,7 @@ def run_tensorrt(
         meas_time = end_time - start_time
         timings.append(meas_time)

-    recordStats("TensorRT", timings, precision, batch_size)
+    recordStats("TensorRT", timings, precision, batch_size, compile_time_s)


 # Deploys inference run for different backend configurations
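recordStats now also receives the measured engine build time. The real helper in perf_run.py appends to a shared results table; a simplified sketch of the statistics such a helper typically derives (the function and key names here are illustrative, not the script's actual output schema):

    import numpy as np

    def record_stats_sketch(backend, timings, precision, batch_size, compile_time_s=None):
        # Illustrative stand-in for recordStats: summarize per-iteration
        # latencies (seconds) into throughput and latency statistics.
        times = np.array(timings)
        return {
            "Backend": backend,
            "Precision": precision,
            "Batch size": batch_size,
            "Median(FPS)": batch_size / np.median(times),
            "Mean(FPS)": batch_size / times.mean(),
            "Median-Latency(ms)": np.median(times) * 1000,
            "Mean-Latency(ms)": times.mean() * 1000,
            "Compile Time(s)": compile_time_s,
        }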
@@ -426,11 +427,10 @@ def run(
             )
         elif backend == "tensorrt":
             run_tensorrt(
-                model,
+                model_torch,
                 input_tensors,
                 params,
                 precision,
-                is_trt_engine,
                 batch_size,
             )
         elif backend == "dynamo":
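With is_trt_engine gone from the signature, run_tensorrt now always starts from an eager module and builds the engine itself. An illustrative call under the new signature (the argument values are made up):

    run_tensorrt(
        model_torch,     # eager torch.nn.Module; exported to ONNX internally
        input_tensors,   # list of CUDA input tensors
        params,          # e.g. {"iterations": 20, "truncate": False}
        precision="fp16",
        batch_size=8,
    )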
@@ -439,9 +439,6 @@ def run(
         elif backend == "torch_compile":
             run_torch_compile(model_torch, input_tensors, params, precision, batch_size)

-        elif backend == "torch_compile":
-            run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
-
         elif backend == "inductor":
             run_inductor(model_torch, input_tensors, params, precision, batch_size)
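The lines removed here were an exact duplicate of the torch_compile branch directly above them, the kind of slip a long if/elif chain invites. A table-driven dispatch (a sketch, not how perf_run.py is actually structured) rules this out, since a repeated key would simply overwrite the earlier entry instead of adding dead code:

    # Sketch of table-driven backend dispatch; assumes all runners share
    # the (model, input_tensors, params, precision, batch_size) signature.
    BACKENDS = {
        "tensorrt": run_tensorrt,
        "torch_compile": run_torch_compile,
        "inductor": run_inductor,
    }

    def dispatch(backend, model_torch, input_tensors, params, precision, batch_size):
        try:
            runner = BACKENDS[backend]
        except KeyError:
            raise ValueError(f"Unknown backend: {backend}") from None
        runner(model_torch, input_tensors, params, precision, batch_size)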