|
15 | 15 | # Importing supported Backends
|
16 | 16 | import torch
|
17 | 17 | import torch_tensorrt as torchtrt
|
18 |
| -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer |
19 |
| -from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter |
20 |
| -from torch_tensorrt.fx import TRTModule |
| 18 | +# import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer |
| 19 | +# from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter |
| 20 | +# from torch_tensorrt.fx import TRTModule |
| 21 | +from torch_tensorrt.fx.lower import lower_to_trt |
| 22 | +from torch_tensorrt.fx.utils import LowerPrecision |
| 23 | + |
21 | 24 | import tensorrt as trt
|
22 | 25 | from utils import parse_inputs, parse_backends, precision_to_dtype, BENCHMARK_MODELS
|
23 | 26 |
|
@@ -113,30 +116,18 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
|
113 | 116 | # Runs inference using FX2TRT backend
|
114 | 117 | def run_fx2trt(model, input_tensors, params, precision, batch_size):
|
115 | 118 | print("Running FX2TRT for precision: ", precision)
|
116 |
| - |
117 |
| - # Trace the model with acc_tracer. |
118 |
| - acc_mod = acc_tracer.trace(model, input_tensors) |
119 |
| - # Generate input specs |
120 |
| - input_specs = InputTensorSpec.from_tensors(input_tensors) |
121 |
| - # Build a TRT interpreter. Set explicit_batch_dimension accordingly. |
122 |
| - interpreter = TRTInterpreter( |
123 |
| - acc_mod, input_specs, explicit_batch_dimension=True |
| 119 | + if precision == "fp32": |
| 120 | + precision = LowerPrecision.FP32 |
| 121 | + elif precision == "fp16": |
| 122 | + precision = LowerPrecision.FP16 |
| 123 | + # Run lowering eager mode benchmark |
| 124 | + model = lower_to_trt( |
| 125 | + model, |
| 126 | + input_tensors, |
| 127 | + max_batch_size=batch_size, |
| 128 | + lower_precision=precision, |
| 129 | + verbose_log=True, |
124 | 130 | )
|
125 |
| - trt_interpreter_result = interpreter.run( |
126 |
| - max_batch_size=batch_size, |
127 |
| - lower_precision=precision, |
128 |
| - max_workspace_size=1 << 25, |
129 |
| - sparse_weights=False, |
130 |
| - force_fp32_output=False, |
131 |
| - strict_type_constraints=False, |
132 |
| - algorithm_selector=None, |
133 |
| - timing_cache=None, |
134 |
| - profiling_verbosity=None) |
135 |
| - |
136 |
| - model = TRTModule( |
137 |
| - trt_interpreter_result.engine, |
138 |
| - trt_interpreter_result.input_names, |
139 |
| - trt_interpreter_result.output_names) |
140 | 131 |
|
141 | 132 | iters = params.get('iterations', 20)
|
142 | 133 | # Warm up
|
|
0 commit comments