Skip to content

Commit 0d51156

Browse files
authored
[5256037] Automatically infer calibration shapes for per node calibration (#394)
Signed-off-by: ajrasane <[email protected]>
1 parent 17439e6 commit 0d51156

File tree

6 files changed

+28
-7
lines changed

6 files changed

+28
-7
lines changed

examples/onnx_ptq/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ python -m modelopt.onnx.quantization \
215215
--quantize_mode=<int8/fp8> \
216216
--calibration_data=calib.npy \
217217
--calibrate_per_node \
218-
--output_path=vit_base_patch16_224.quant.onnx \
219-
--calibration_shapes=input:1x3x224x224
218+
--output_path=vit_base_patch16_224.quant.onnx
220219
```
221220

222221
> **Note**: Per node calibration is not available for INT4 quantization methods (`awq_clip`, `rtn_dq`)

modelopt/onnx/quantization/fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def quantize(
164164
calibration_method: str = "entropy",
165165
calibration_data_reader: CalibrationDataReader = None,
166166
calibration_cache_path: str | None = None,
167-
calibration_shapes: str | None = None,
167+
calibration_shapes: str | dict | None = None,
168168
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
169169
op_types_to_quantize: list[str] | None = None,
170170
op_types_to_exclude: list[str] | None = None,

modelopt/onnx/quantization/graph_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ def has_const_input(node: Node) -> bool:
7272
return any(is_const_input(tensor) for tensor in node.inputs)
7373

7474

75+
def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
    """Return the shape of every graph input of the given ONNX model.

    Args:
        onnx_path: Path to the ONNX model file.

    Returns:
        Mapping from graph input name to its list of dimension sizes. A dimension
        that carries no concrete ``dim_value`` (for example a symbolic/dynamic one)
        is reported as ``0``, the protobuf default for that field.
    """
    # Only the graph signature is inspected, so skip reading externally stored weights.
    onnx_model = onnx.load(onnx_path, load_external_data=False)
    # `graph_input` rather than `input` so the builtin is not shadowed.
    return {
        graph_input.name: [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        for graph_input in onnx_model.graph.input
    }
82+
83+
7584
def has_path_type(
7685
node: Node,
7786
graph: Graph,
@@ -923,7 +932,7 @@ def find_nodes_from_matmul_to_exclude(
923932
intermediate_generated_files: list[str] | None = None,
924933
calibration_data_reader: CalibrationDataReader = None,
925934
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
926-
calibration_shapes: str | None = None,
935+
calibration_shapes: str | dict | None = None,
927936
) -> list[str]:
928937
"""Find MatMul nodes that meets gemv condition to exclude.
929938
@@ -1050,7 +1059,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
10501059

10511060

10521061
def _exclude_matmuls_by_symbolic_inference(
1053-
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | None = None
1062+
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
10541063
) -> list[str]:
10551064
"""Use symbolic shape inference to find MatMuls with dimension 1."""
10561065
# Prepare model for symbolic inference
@@ -1061,7 +1070,11 @@ def _exclude_matmuls_by_symbolic_inference(
10611070
dim.dim_value = 1
10621071

10631072
# Apply calibration shapes if provided
1064-
input_shapes = parse_shapes_spec(calibration_shapes) if calibration_shapes else {}
1073+
input_shapes = (
1074+
parse_shapes_spec(calibration_shapes)
1075+
if (calibration_shapes and isinstance(calibration_shapes, str))
1076+
else {}
1077+
)
10651078
for graph_input in model.graph.input:
10661079
if graph_input.name in input_shapes:
10671080
input_shape = input_shapes[graph_input.name]

modelopt/onnx/quantization/int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def quantize(
115115
calibration_method: str = "entropy",
116116
calibration_data_reader: CalibrationDataReader = None,
117117
calibration_cache_path: str | None = None,
118-
calibration_shapes: str | None = None,
118+
calibration_shapes: str | dict | None = None,
119119
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
120120
op_types_to_quantize: list[str] | None = None,
121121
op_types_to_exclude: list[str] | None = None,

modelopt/onnx/quantization/quantize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from modelopt.onnx.quantization.graph_utils import (
5454
cast_custom_ops,
5555
find_nodes_from_mha_to_exclude,
56+
get_input_shapes,
5657
print_stat,
5758
remove_redundant_cast_nodes,
5859
validate_op_types_spelling,
@@ -255,6 +256,8 @@ def quantize(
255256
Path to pre-calculated activation tensor ranges, also known as calibration cache.
256257
calibration_shapes:
257258
Input shapes used for calibration process.
259+
It should be provided as a string representing the shape of each input tensor for one calibration step.
260+
Example input shapes spec: input0:1x3x256x256,input1:1x3x128x128
258261
calibration_eps:
259262
Priority order for the execution providers (EP) to calibrate the model.
260263
Any subset of ['NvTensorRtRtx', 'trt', 'cuda:x', 'dml:x', 'cpu'], where 'x' is the device id.
@@ -467,6 +470,9 @@ def quantize(
467470
calibration_eps,
468471
)
469472

473+
if not calibration_shapes:
474+
calibration_shapes = get_input_shapes(onnx_path)
475+
470476
if quantize_mode in ["fp8", "int8"]:
471477
quantize_func = quantize_int8 if quantize_mode == "int8" else quantize_fp8
472478
onnx_model = quantize_func(

modelopt/onnx/trt_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ def load_onnx_model(
266266
custom_ops = []
267267
has_custom_op = False
268268

269+
# Infer shapes
270+
onnx.shape_inference.infer_shapes_path(onnx_path)
271+
269272
# Load the model and weights
270273
onnx_model = onnx.load(onnx_path, load_external_data=True)
271274
size_threshold = 2 * (1024**3) # 2GB

0 commit comments

Comments
 (0)