Skip to content

Commit a9489e9

Browse files
ajrasanekevalmorabia97
authored andcommitted
[5256037] Automatically infer calibration shapes for per node calibration (#394)
Signed-off-by: ajrasane <[email protected]>
1 parent d56f03c commit a9489e9

File tree

6 files changed

+28
-7
lines changed

6 files changed

+28
-7
lines changed

examples/onnx_ptq/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ python -m modelopt.onnx.quantization \
215215
--quantize_mode=<int8/fp8> \
216216
--calibration_data=calib.npy \
217217
--calibrate_per_node \
218-
--output_path=vit_base_patch16_224.quant.onnx \
219-
--calibration_shapes=input:1x3x224x224
218+
--output_path=vit_base_patch16_224.quant.onnx
220219
```
221220

222221
> **Note**: Per node calibration is not available for INT4 quantization methods (`awq_clip`, `rtn_dq`)

modelopt/onnx/quantization/fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def quantize(
164164
calibration_method: str = "entropy",
165165
calibration_data_reader: CalibrationDataReader = None,
166166
calibration_cache_path: str | None = None,
167-
calibration_shapes: str | None = None,
167+
calibration_shapes: str | dict | None = None,
168168
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
169169
op_types_to_quantize: list[str] | None = None,
170170
op_types_to_exclude: list[str] | None = None,

modelopt/onnx/quantization/graph_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ def has_const_input(node: Node) -> bool:
7171
return any(is_const_input(tensor) for tensor in node.inputs)
7272

7373

74+
def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
75+
"""Returns the input shapes of the given ONNX model."""
76+
onnx_model = onnx.load(onnx_path)
77+
input_shape_dict = {}
78+
for input in onnx_model.graph.input:
79+
input_shape_dict[input.name] = [x.dim_value for x in input.type.tensor_type.shape.dim]
80+
return input_shape_dict
81+
82+
7483
def has_path_type(
7584
node: Node,
7685
graph: Graph,
@@ -707,7 +716,7 @@ def find_nodes_from_matmul_to_exclude(
707716
intermediate_generated_files: list[str] | None = None,
708717
calibration_data_reader: CalibrationDataReader = None,
709718
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
710-
calibration_shapes: str | None = None,
719+
calibration_shapes: str | dict | None = None,
711720
) -> list[str]:
712721
"""Find MatMul nodes that meets gemv condition to exclude.
713722
@@ -834,7 +843,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
834843

835844

836845
def _exclude_matmuls_by_symbolic_inference(
837-
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | None = None
846+
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
838847
) -> list[str]:
839848
"""Use symbolic shape inference to find MatMuls with dimension 1."""
840849
# Prepare model for symbolic inference
@@ -845,7 +854,11 @@ def _exclude_matmuls_by_symbolic_inference(
845854
dim.dim_value = 1
846855

847856
# Apply calibration shapes if provided
848-
input_shapes = parse_shapes_spec(calibration_shapes) if calibration_shapes else {}
857+
input_shapes = (
858+
parse_shapes_spec(calibration_shapes)
859+
if (calibration_shapes and isinstance(calibration_shapes, str))
860+
else {}
861+
)
849862
for graph_input in model.graph.input:
850863
if graph_input.name in input_shapes:
851864
input_shape = input_shapes[graph_input.name]

modelopt/onnx/quantization/int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def quantize(
115115
calibration_method: str = "entropy",
116116
calibration_data_reader: CalibrationDataReader = None,
117117
calibration_cache_path: str | None = None,
118-
calibration_shapes: str | None = None,
118+
calibration_shapes: str | dict | None = None,
119119
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
120120
op_types_to_quantize: list[str] | None = None,
121121
op_types_to_exclude: list[str] | None = None,

modelopt/onnx/quantization/quantize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from modelopt.onnx.quantization.graph_utils import (
5454
cast_custom_ops,
5555
find_nodes_from_mha_to_exclude,
56+
get_input_shapes,
5657
print_stat,
5758
remove_redundant_cast_nodes,
5859
validate_op_types_spelling,
@@ -251,6 +252,8 @@ def quantize(
251252
Path to pre-calculated activation tensor ranges, also known as calibration cache.
252253
calibration_shapes:
253254
Input shapes used for calibration process.
255+
It should be provided as a string representing the shape of each input tensors for one calibration step.
256+
Example input shapes spec: input0:1x3x256x256,input1:1x3x128x128
254257
calibration_eps:
255258
Priority order for the execution providers (EP) to calibrate the model.
256259
Any subset of ['NvTensorRtRtx', 'trt', 'cuda:x', 'dml:x', 'cpu'], where 'x' is the device id.
@@ -463,6 +466,9 @@ def quantize(
463466
calibration_eps,
464467
)
465468

469+
if not calibration_shapes:
470+
calibration_shapes = get_input_shapes(onnx_path)
471+
466472
if quantize_mode in ["fp8", "int8"]:
467473
quantize_func = quantize_int8 if quantize_mode == "int8" else quantize_fp8
468474
onnx_model = quantize_func(

modelopt/onnx/trt_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ def load_onnx_model(
266266
custom_ops = []
267267
has_custom_op = False
268268

269+
# Infer shapes
270+
onnx.shape_inference.infer_shapes_path(onnx_path)
271+
269272
# Load the model and weights
270273
onnx_model = onnx.load(onnx_path, load_external_data=True)
271274
size_threshold = 2 * (1024**3) # 2GB

0 commit comments

Comments
 (0)