Skip to content

Commit 67305e4

Browse files
committed
Add timing cache to the evaluate API
Signed-off-by: ajrasane <[email protected]>
1 parent 685183d commit 67305e4

File tree

4 files changed

+24
-8
lines changed

4 files changed

+24
-8
lines changed

examples/onnx_ptq/evaluate.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ def main():
3838
parser.add_argument(
3939
"--engine_path",
4040
type=str,
41-
required=True,
41+
default=None,
4242
help="Path to the TensorRT engine",
4343
)
44+
parser.add_argument(
45+
"--timing_cache_path",
46+
type=str,
47+
default=None,
48+
help="Path to the TensorRT timing cache",
49+
)
4450
parser.add_argument(
4551
"--imagenet_path", type=str, default=None, help="Path to the imagenet dataset"
4652
)
@@ -81,6 +87,7 @@ def main():
8187
# Compile the ONNX model to TRT engine and create the device model
8288
compilation_args = {
8389
"engine_path": args.engine_path,
90+
"timing_cache_path": args.timing_cache_path,
8491
}
8592
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
8693
device_model = DeviceModel(client, compiled_model, metadata={})

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def build_engine(
123123
onnx_bytes: OnnxBytes,
124124
trt_mode: str = TRTMode.FLOAT32,
125125
engine_path: Path | None = None,
126+
timing_cache_path: Path | None = None,
126127
calib_cache: str | None = None,
127128
dynamic_shapes: dict | None = None,
128129
plugin_config: dict | None = None,
@@ -135,6 +136,7 @@ def build_engine(
135136
Args:
136137
onnx_bytes: Data of the ONNX model stored as an OnnxBytes object.
137138
engine_path: Path to save the TensorRT engine.
139+
timing_cache_path: Path to save/load the TensorRT timing cache.
138140
trt_mode: The precision with which the TensorRT engine will be built. Supported modes are:
139141
- TRTMode.FLOAT32
140142
- TRTMode.FLOAT16
@@ -205,6 +207,7 @@ def _build_command(
205207
def _setup_files_and_paths(
206208
tmp_dir_path: Path,
207209
engine_path: Path | None,
210+
timing_cache_path: Path | None,
208211
) -> tuple[Path, Path, Path | None, Path | None, Path]:
209212
tmp_onnx_dir = tmp_dir_path / "onnx"
210213
onnx_bytes.write_to_disk(str(tmp_onnx_dir))
@@ -219,13 +222,15 @@ def _setup_files_and_paths(
219222
)
220223
engine_path.parent.mkdir(parents=True, exist_ok=True)
221224
calib_cache_path = final_output_dir / "calib_cache" if calib_cache else None
222-
timing_cache_path = final_output_dir / "timing.cache"
225+
timing_cache_path = (
226+
Path(timing_cache_path) if timing_cache_path else final_output_dir / "timing.cache"
227+
)
223228

224229
return onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir
225230

226231
with TemporaryDirectory() as tmp_dir:
227232
onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir = (
228-
_setup_files_and_paths(Path(tmp_dir), engine_path)
233+
_setup_files_and_paths(Path(tmp_dir), engine_path, timing_cache_path)
229234
)
230235
cmd = _build_command(onnx_path, engine_path, calib_cache_path, timing_cache_path)
231236

modelopt/torch/_deploy/_runtime/trt_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _ir_to_compiled(
7474
Args:
7575
ir_bytes: The ONNX model bytes.
7676
compilation_args: A dictionary of compilation arguments.
77-
The following arguments are supported: dynamic_shapes, plugin_config, engine_path.
77+
The following arguments are supported: dynamic_shapes, plugin_config, engine_path, timing_cache_path.
7878
7979
Returns:
8080
The compiled TRT engine bytes.
@@ -87,6 +87,7 @@ def _ir_to_compiled(
8787
dynamic_shapes=compilation_args.get("dynamic_shapes"), # type: ignore[union-attr]
8888
plugin_config=compilation_args.get("plugin_config"), # type: ignore[union-attr]
8989
engine_path=compilation_args.get("engine_path"), # type: ignore[union-attr]
90+
timing_cache_path=compilation_args.get("timing_cache_path"), # type: ignore[union-attr]
9091
trt_mode=self.deployment["precision"],
9192
verbose=(self.deployment.get("verbose", "false").lower() == "true"),
9293
)

tests/examples/test_onnx_ptq.sh

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ for model_path in "${model_paths[@]}"; do
149149
--quantize_mode=$quant_mode \
150150
--calibration_data=$calib_data_path \
151151
--output_path=$model_dir/$quant_mode/model.quant.onnx \
152-
--calibration_eps=cuda:0
152+
--calibration_eps=cuda
153153
done
154154

155155
echo "Completed quantization of all modes for model: $model_name"
@@ -184,7 +184,8 @@ for model_path in "${model_paths[@]}"; do
184184
--engine_path=$engine_path \
185185
--model_name="${timm_model_name[$model_name]}" \
186186
--engine_precision=$precision \
187-
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv
187+
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv \
188+
--timing_cache_path=build/timing.cache
188189
else
189190
python evaluate.py \
190191
--onnx_path=$eval_model_path \
@@ -194,7 +195,8 @@ for model_path in "${model_paths[@]}"; do
194195
--batch_size $batch_size \
195196
--model_name="${timm_model_name[$model_name]}" \
196197
--engine_precision=$precision \
197-
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv
198+
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv \
199+
--timing_cache_path=build/timing.cache
198200
fi
199201
done
200202

@@ -214,4 +216,5 @@ fi
214216
popd
215217

216218

217-
echo "Total wall time: $(($(date +%s) - start_time)) seconds"
219+
total_seconds=$(($(date +%s) - start_time))
220+
printf "Total wall time: %02d:%02d:%02d\n" $((total_seconds/3600)) $(((total_seconds%3600)/60)) $((total_seconds%60))

0 commit comments

Comments
 (0)