Skip to content

Commit c692074

Browse files
authored
Update ONNX PTQ test to be single-threaded and make it faster (#415)
Signed-off-by: ajrasane <[email protected]>
1 parent f8a9353 commit c692074

File tree

5 files changed

+82
-91
lines changed

5 files changed

+82
-91
lines changed

examples/onnx_ptq/evaluate.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ def main():
3838
parser.add_argument(
3939
"--engine_path",
4040
type=str,
41-
required=True,
41+
default=None,
4242
help="Path to the TensorRT engine",
4343
)
44+
parser.add_argument(
45+
"--timing_cache_path",
46+
type=str,
47+
default=None,
48+
help="Path to the TensorRT timing cache",
49+
)
4450
parser.add_argument(
4551
"--imagenet_path", type=str, default=None, help="Path to the imagenet dataset"
4652
)
@@ -81,6 +87,7 @@ def main():
8187
# Compile the ONNX model to TRT engine and create the device model
8288
compilation_args = {
8389
"engine_path": args.engine_path,
90+
"timing_cache_path": args.timing_cache_path,
8491
}
8592
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
8693
device_model = DeviceModel(client, compiled_model, metadata={})

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def build_engine(
123123
onnx_bytes: OnnxBytes,
124124
trt_mode: str = TRTMode.FLOAT32,
125125
engine_path: Path | None = None,
126+
timing_cache_path: Path | None = None,
126127
calib_cache: str | None = None,
127128
dynamic_shapes: dict | None = None,
128129
plugin_config: dict | None = None,
@@ -135,6 +136,7 @@ def build_engine(
135136
Args:
136137
onnx_bytes: Data of the ONNX model stored as an OnnxBytes object.
137138
engine_path: Path to save the TensorRT engine.
139+
timing_cache_path: Path to save/load the TensorRT timing cache. If not provided, a "timing.cache" file in the final output directory is used.
138140
trt_mode: The precision with which the TensorRT engine will be built. Supported modes are:
139141
- TRTMode.FLOAT32
140142
- TRTMode.FLOAT16
@@ -205,6 +207,7 @@ def _build_command(
205207
def _setup_files_and_paths(
206208
tmp_dir_path: Path,
207209
engine_path: Path | None,
210+
timing_cache_path: Path | None,
208211
) -> tuple[Path, Path, Path | None, Path | None, Path]:
209212
tmp_onnx_dir = tmp_dir_path / "onnx"
210213
onnx_bytes.write_to_disk(str(tmp_onnx_dir))
@@ -219,13 +222,15 @@ def _setup_files_and_paths(
219222
)
220223
engine_path.parent.mkdir(parents=True, exist_ok=True)
221224
calib_cache_path = final_output_dir / "calib_cache" if calib_cache else None
222-
timing_cache_path = final_output_dir / "timing.cache"
225+
timing_cache_path = (
226+
Path(timing_cache_path) if timing_cache_path else final_output_dir / "timing.cache"
227+
)
223228

224229
return onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir
225230

226231
with TemporaryDirectory() as tmp_dir:
227232
onnx_path, engine_path, calib_cache_path, timing_cache_path, final_output_dir = (
228-
_setup_files_and_paths(Path(tmp_dir), engine_path)
233+
_setup_files_and_paths(Path(tmp_dir), engine_path, timing_cache_path)
229234
)
230235
cmd = _build_command(onnx_path, engine_path, calib_cache_path, timing_cache_path)
231236

modelopt/torch/_deploy/_runtime/trt_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _ir_to_compiled(
7474
Args:
7575
ir_bytes: The ONNX model bytes.
7676
compilation_args: A dictionary of compilation arguments.
77-
The following arguments are supported: dynamic_shapes, plugin_config, engine_path.
77+
The following arguments are supported: dynamic_shapes, plugin_config, engine_path, timing_cache_path.
7878
7979
Returns:
8080
The compiled TRT engine bytes.
@@ -87,6 +87,7 @@ def _ir_to_compiled(
8787
dynamic_shapes=compilation_args.get("dynamic_shapes"), # type: ignore[union-attr]
8888
plugin_config=compilation_args.get("plugin_config"), # type: ignore[union-attr]
8989
engine_path=compilation_args.get("engine_path"), # type: ignore[union-attr]
90+
timing_cache_path=compilation_args.get("timing_cache_path"), # type: ignore[union-attr]
9091
trt_mode=self.deployment["precision"],
9192
verbose=(self.deployment.get("verbose", "false").lower() == "true"),
9293
)

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def export_as_onnx(
124124
output_names=output_names,
125125
opset_version=opset,
126126
do_constant_folding=do_constant_folding,
127+
dynamo=False,
127128
)
128129

129130

tests/examples/test_onnx_ptq.sh

Lines changed: 64 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# It is recommended to execute this script inside the Model Optimization Toolkit TensorRT Docker container.
2222
# Please ensure that the ImageNet dataset is available in the container at the specified path.
2323

24-
# Usage: ./test_onnx_ptq.sh [--no-clean] [/path/to/imagenet] [/path/to/models]
24+
# Usage: ./test_onnx_ptq.sh [--no-clean] [--eval] [/path/to/imagenet] [/path/to/models] [/path/to/timing_cache]
2525

2626
set -exo pipefail
2727

@@ -37,20 +37,28 @@ pushd $public_example_dir
3737

3838
# Parse arguments
3939
clean_mode=true
40+
eval_mode=false
4041
imagenet_path=""
4142
models_folder=""
43+
timing_cache_path=""
4244

4345
for arg in "$@"; do
4446
case $arg in
4547
--no-clean)
4648
clean_mode=false
4749
shift
4850
;;
51+
--eval)
52+
eval_mode=true
53+
shift
54+
;;
4955
*)
5056
if [ -z "$imagenet_path" ]; then
5157
imagenet_path="$arg"
5258
elif [ -z "$models_folder" ]; then
5359
models_folder="$arg"
60+
elif [ -z "$timing_cache_path" ]; then
61+
timing_cache_path="$arg"
5462
fi
5563
shift
5664
;;
@@ -63,7 +71,9 @@ export TQDM_DISABLE=1
6371
# Setting image and model paths (contains 8 models)
6472
imagenet_path=${imagenet_path:-/data/imagenet/}
6573
models_folder=${models_folder:-/models/onnx}
66-
calib_size=64
74+
timing_cache_path=${timing_cache_path:-/models/onnx/build/timing.cache}
75+
calib_size=1
76+
eval_size=100
6777
batch_size=1
6878

6979

@@ -137,117 +147,84 @@ for model_path in "${model_paths[@]}"; do
137147
model_name=$(basename "$model_path" .onnx)
138148
model_dir=build/$model_name
139149

140-
141-
echo "Quantizing model $model_name for all quantization modes in parallel"
142-
pids=()
143-
for i in "${!quant_modes[@]}"; do
144-
quant_mode="${quant_modes[$i]}"
145-
gpu_id=$((i % nvidia_gpu_count))
150+
echo "Quantizing model $model_name for all quantization modes"
151+
for quant_mode in "${quant_modes[@]}"; do
146152
if [ "$quant_mode" == "int8_iq" ]; then
147153
continue
148154
fi
149155

150-
echo "Starting quantization of $model_name for mode: $quant_mode on GPU $gpu_id"
151-
CUDA_VISIBLE_DEVICES=$gpu_id python -m modelopt.onnx.quantization \
156+
echo "Starting quantization of $model_name for mode: $quant_mode"
157+
python -m modelopt.onnx.quantization \
152158
--onnx_path=$model_dir/fp16/model.onnx \
153159
--quantize_mode=$quant_mode \
154160
--calibration_data=$calib_data_path \
155161
--output_path=$model_dir/$quant_mode/model.quant.onnx \
156-
--calibration_eps=cuda:0 &
157-
pids+=($!)
158-
done
159-
160-
# Wait for all quantization processes to complete for this model
161-
error_occurred=false
162-
for pid in "${pids[@]}"; do
163-
if ! wait $pid; then
164-
echo "ERROR: Quantization process (PID: $pid) failed"
165-
error_occurred=true
166-
fi
162+
--calibration_eps=cuda
167163
done
168-
if [ "$error_occurred" = true ]; then
169-
echo "Stopping execution due to quantization failure for model: $model_name"
170-
exit 1
171-
fi
172164

173165
echo "Completed quantization of all modes for model: $model_name"
174166
done
175167

176168

177169
# Evaluate the quantized models for each mode (runs only when --eval is passed)
178-
for model_path in "${model_paths[@]}"; do
179-
model_name=$(basename "$model_path" .onnx)
180-
model_dir=build/$model_name
181-
182-
echo "Evaluating model $model_name for all quantization modes in parallel"
183-
pids=()
184-
for i in "${!all_modes[@]}"; do
185-
quant_mode="${all_modes[$i]}"
186-
gpu_id=$((i % nvidia_gpu_count))
187-
188-
if [ "$quant_mode" == "fp16" ]; then
189-
eval_model_path=$model_dir/fp16/model.onnx
190-
engine_path=$model_dir/fp16/model.engine
191-
precision="fp16"
192-
elif [ "$quant_mode" == "int8_iq" ]; then
193-
eval_model_path=$model_dir/fp16/model.onnx
194-
engine_path=$model_dir/int8_iq/model.engine
195-
precision="best"
196-
else
197-
eval_model_path=$model_dir/$quant_mode/model.quant.onnx
198-
engine_path=$model_dir/$quant_mode/model.quant.engine
199-
precision="stronglyTyped"
200-
fi
170+
if [ "$eval_mode" = true ]; then
171+
for model_path in "${model_paths[@]}"; do
172+
model_name=$(basename "$model_path" .onnx)
173+
model_dir=build/$model_name
174+
175+
echo "Evaluating model $model_name for all quantization modes"
176+
for quant_mode in "${all_modes[@]}"; do
177+
if [ "$quant_mode" == "fp16" ]; then
178+
eval_model_path=$model_dir/fp16/model.onnx
179+
engine_path=$model_dir/fp16/model.engine
180+
precision="fp16"
181+
elif [ "$quant_mode" == "int8_iq" ]; then
182+
eval_model_path=$model_dir/fp16/model.onnx
183+
engine_path=$model_dir/int8_iq/model.engine
184+
precision="best"
185+
else
186+
eval_model_path=$model_dir/$quant_mode/model.quant.onnx
187+
engine_path=$model_dir/$quant_mode/model.quant.engine
188+
precision="stronglyTyped"
189+
fi
201190

202-
echo "Starting evaluation of $model_name for mode: $quant_mode on GPU $gpu_id"
203-
if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
204-
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
205-
--onnx_path=$eval_model_path \
206-
--engine_path=$engine_path \
207-
--model_name="${timm_model_name[$model_name]}" \
208-
--engine_precision=$precision \
209-
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
210-
else
211-
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
212-
--onnx_path=$eval_model_path \
213-
--engine_path=$engine_path \
214-
--imagenet_path=$imagenet_path \
215-
--eval_data_size=$calib_size \
216-
--batch_size $batch_size \
217-
--model_name="${timm_model_name[$model_name]}" \
218-
--engine_precision=$precision \
219-
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
220-
fi
221-
pids+=($!)
222-
done
191+
echo "Starting evaluation of $model_name for mode: $quant_mode"
192+
if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
193+
python evaluate.py \
194+
--onnx_path=$eval_model_path \
195+
--engine_path=$engine_path \
196+
--model_name="${timm_model_name[$model_name]}" \
197+
--engine_precision=$precision \
198+
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv \
199+
--timing_cache_path=$timing_cache_path
200+
else
201+
python evaluate.py \
202+
--onnx_path=$eval_model_path \
203+
--engine_path=$engine_path \
204+
--imagenet_path=$imagenet_path \
205+
--eval_data_size=$eval_size \
206+
--batch_size $batch_size \
207+
--model_name="${timm_model_name[$model_name]}" \
208+
--engine_precision=$precision \
209+
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv \
210+
--timing_cache_path=$timing_cache_path
211+
fi
212+
done
223213

224-
# Wait for all evaluation processes to complete for this model
225-
error_occurred=false
226-
for pid in "${pids[@]}"; do
227-
if ! wait $pid; then
228-
echo "ERROR: Evaluation process (PID: $pid) failed"
229-
error_occurred=true
230-
fi
214+
echo "Completed evaluation of all modes for model: $model_name"
231215
done
232-
if [ "$error_occurred" = true ]; then
233-
echo "Stopping execution due to evaluation failure for model: $model_name"
234-
exit 1
235-
fi
236-
237-
echo "Completed evaluation of all modes for model: $model_name"
238-
done
239216

240-
python $test_utils_dir/aggregate_results.py --results_dir=build
217+
python $test_utils_dir/aggregate_results.py --results_dir=build
218+
fi
241219

242220
if [ "$clean_mode" = true ]; then
243221
echo "Cleaning build artifacts..."
244222
rm -rf build/
245223
echo "Build artifacts cleaned successfully."
246-
popd
247-
exit 0
248224
fi
249225

250226
popd
251227

252228

253-
echo "Total wall time: $(($(date +%s) - start_time)) seconds"
229+
total_seconds=$(($(date +%s) - start_time))
230+
printf "Total wall time: %02d:%02d:%02d\n" $((total_seconds/3600)) $(((total_seconds%3600)/60)) $((total_seconds%60))

0 commit comments

Comments
 (0)