
Commit d108f87

chore: Refactor fx2trt functionality
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6cbf600 commit d108f87

1 file changed (+17 / -26 lines)


tools/perf/perf_run.py

Lines changed: 17 additions & 26 deletions
@@ -15,9 +15,12 @@
 # Importing supported Backends
 import torch
 import torch_tensorrt as torchtrt
-import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
-from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
-from torch_tensorrt.fx import TRTModule
+# import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
+# from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
+# from torch_tensorrt.fx import TRTModule
+from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.utils import LowerPrecision
+
 import tensorrt as trt
 from utils import parse_inputs, parse_backends, precision_to_dtype, BENCHMARK_MODELS


@@ -113,30 +116,18 @@ def run_torch_tensorrt(model, input_tensors, params, precision, truncate_long_an
 # Runs inference using FX2TRT backend
 def run_fx2trt(model, input_tensors, params, precision, batch_size):
     print("Running FX2TRT for precision: ", precision)
-
-    # Trace the model with acc_tracer.
-    acc_mod = acc_tracer.trace(model, input_tensors)
-    # Generate input specs
-    input_specs = InputTensorSpec.from_tensors(input_tensors)
-    # Build a TRT interpreter. Set explicit_batch_dimension accordingly.
-    interpreter = TRTInterpreter(
-        acc_mod, input_specs, explicit_batch_dimension=True
+    if precision == "fp32":
+        precision = LowerPrecision.FP32
+    elif precision == "fp16":
+        precision = LowerPrecision.FP16
+    # Run lowering eager mode benchmark
+    model = lower_to_trt(
+        model,
+        input_tensors,
+        max_batch_size=batch_size,
+        lower_precision=precision,
+        verbose_log=True,
     )
-    trt_interpreter_result = interpreter.run(
-        max_batch_size=batch_size,
-        lower_precision=precision,
-        max_workspace_size=1 << 25,
-        sparse_weights=False,
-        force_fp32_output=False,
-        strict_type_constraints=False,
-        algorithm_selector=None,
-        timing_cache=None,
-        profiling_verbosity=None)
-
-    model = TRTModule(
-        trt_interpreter_result.engine,
-        trt_interpreter_result.input_names,
-        trt_interpreter_result.output_names)

     iters = params.get('iterations', 20)
     # Warm up
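
For context, the refactor funnels everything through a single lowering call: the precision string is mapped to a LowerPrecision enum, and the model plus example inputs are handed to lower_to_trt, replacing the old acc_tracer trace / TRTInterpreter build / TRTModule wrap sequence. Below is a minimal sketch of exercising that path in isolation; the torchvision model, the benchmark_fx2trt helper, and the timing loop are illustrative assumptions, while lower_to_trt, LowerPrecision, and the keyword arguments are the ones used in this diff.

import time

import torch
import torchvision.models as models

from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision


def benchmark_fx2trt(model, input_tensors, precision="fp16", batch_size=1, iters=20):
    # Map the CLI-style precision string to the fx lowering enum,
    # mirroring the branch added to run_fx2trt.
    lower_precision = LowerPrecision.FP32 if precision == "fp32" else LowerPrecision.FP16

    # One call lowers the module to TensorRT (same keyword arguments as the diff).
    trt_model = lower_to_trt(
        model,
        input_tensors,
        max_batch_size=batch_size,
        lower_precision=lower_precision,
        verbose_log=True,
    )

    # Warm up, then time the lowered module (illustrative timing loop).
    with torch.no_grad():
        for _ in range(10):
            trt_model(*input_tensors)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            trt_model(*input_tensors)
        torch.cuda.synchronize()
    print("Average latency: %.2f ms" % ((time.perf_counter() - start) / iters * 1000))


if __name__ == "__main__":
    mod = models.resnet18().eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
    benchmark_fx2trt(mod, inputs, precision="fp16", batch_size=1)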
