diff --git a/examples/onnx_ptq/README.md b/examples/onnx_ptq/README.md
index 483dde359..012cbecf1 100644
--- a/examples/onnx_ptq/README.md
+++ b/examples/onnx_ptq/README.md
@@ -215,8 +215,7 @@ python -m modelopt.onnx.quantization \
     --quantize_mode=<int8|fp8> \
     --calibration_data=calib.npy \
     --calibrate_per_node \
-    --output_path=vit_base_patch16_224.quant.onnx \
-    --calibration_shapes=input:1x3x224x224
+    --output_path=vit_base_patch16_224.quant.onnx
 ```
 
 > **Note**: Per node calibration is not available for INT4 quantization methods (`awq_clip`, `rtn_dq`)
diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py
index 1ef3c9799..ce7d56a26 100755
--- a/modelopt/onnx/quantization/fp8.py
+++ b/modelopt/onnx/quantization/fp8.py
@@ -164,7 +164,7 @@ def quantize(
     calibration_method: str = "entropy",
     calibration_data_reader: CalibrationDataReader = None,
     calibration_cache_path: str | None = None,
-    calibration_shapes: str | None = None,
+    calibration_shapes: str | dict | None = None,
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py
index 8a5b8ad57..6b37e3e7e 100755
--- a/modelopt/onnx/quantization/graph_utils.py
+++ b/modelopt/onnx/quantization/graph_utils.py
@@ -72,6 +72,15 @@ def has_const_input(node: Node) -> bool:
     return any(is_const_input(tensor) for tensor in node.inputs)
 
 
+def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
+    """Returns the input shapes of the given ONNX model."""
+    onnx_model = onnx.load(onnx_path)
+    input_shape_dict = {}
+    for inp in onnx_model.graph.input:
+        input_shape_dict[inp.name] = [dim.dim_value for dim in inp.type.tensor_type.shape.dim]
+    return input_shape_dict
+
+
 def has_path_type(
     node: Node,
     graph: Graph,
@@ -923,7 +932,7 @@ def find_nodes_from_matmul_to_exclude(
     intermediate_generated_files: list[str] | None = None,
     calibration_data_reader: CalibrationDataReader = None,
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
-    calibration_shapes: str | None = None,
+    calibration_shapes: str | dict | None = None,
 ) -> list[str]:
     """Find MatMul nodes that meets gemv condition to exclude.
 
@@ -1050,7 +1059,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
 
 
 def _exclude_matmuls_by_symbolic_inference(
-    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | None = None
+    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
 ) -> list[str]:
     """Use symbolic shape inference to find MatMuls with dimension 1."""
     # Prepare model for symbolic inference
@@ -1061,7 +1070,11 @@ def _exclude_matmuls_by_symbolic_inference(
             dim.dim_value = 1
 
     # Apply calibration shapes if provided
-    input_shapes = parse_shapes_spec(calibration_shapes) if calibration_shapes else {}
+    input_shapes = (
+        parse_shapes_spec(calibration_shapes)
+        if (calibration_shapes and isinstance(calibration_shapes, str))
+        else {}
+    )
     for graph_input in model.graph.input:
         if graph_input.name in input_shapes:
             input_shape = input_shapes[graph_input.name]
diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py
index 5a878fb76..baf5a4383 100755
--- a/modelopt/onnx/quantization/int8.py
+++ b/modelopt/onnx/quantization/int8.py
@@ -115,7 +115,7 @@ def quantize(
     calibration_method: str = "entropy",
     calibration_data_reader: CalibrationDataReader = None,
     calibration_cache_path: str | None = None,
-    calibration_shapes: str | None = None,
+    calibration_shapes: str | dict | None = None,
     calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py
index daf785326..2d23b875a 100755
--- a/modelopt/onnx/quantization/quantize.py
+++ b/modelopt/onnx/quantization/quantize.py
@@ -53,6 +53,7 @@
 from modelopt.onnx.quantization.graph_utils import (
     cast_custom_ops,
     find_nodes_from_mha_to_exclude,
+    get_input_shapes,
     print_stat,
     remove_redundant_cast_nodes,
     validate_op_types_spelling,
@@ -255,6 +256,8 @@ def quantize(
             Path to pre-calculated activation tensor ranges, also known as calibration cache.
         calibration_shapes:
             Input shapes used for calibration process.
+            It can be provided as a string describing the shape of each input tensor for one
+            calibration step, e.g. input0:1x3x256x256,input1:1x3x128x128, or as a dict of input shapes.
         calibration_eps:
             Priority order for the execution providers (EP) to calibrate the model.
             Any subset of ['NvTensorRtRtx', 'trt', 'cuda:x', 'dml:x', 'cpu'], where 'x' is the device id.
@@ -467,6 +470,9 @@ def quantize(
         calibration_eps,
     )
 
+    if not calibration_shapes:
+        calibration_shapes = get_input_shapes(onnx_path)
+
     if quantize_mode in ["fp8", "int8"]:
         quantize_func = quantize_int8 if quantize_mode == "int8" else quantize_fp8
         onnx_model = quantize_func(
diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py
index 2231ccd00..e5a5d9a4a 100644
--- a/modelopt/onnx/trt_utils.py
+++ b/modelopt/onnx/trt_utils.py
@@ -266,6 +266,9 @@ def load_onnx_model(
     custom_ops = []
     has_custom_op = False
 
+    # Infer shapes
+    onnx.shape_inference.infer_shapes_path(onnx_path)
+
     # Load the model and weights
     onnx_model = onnx.load(onnx_path, load_external_data=True)
     size_threshold = 2 * (1024**3)  # 2GB
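As a usage sketch (not part of the diff): the snippet below builds a toy two-input model and shows the dict that the new `get_input_shapes` helper returns, which `quantize()` now falls back to when no `calibration_shapes` is given. The toy model and file name are illustrative only. One caveat: symbolic dimensions carry `dim_value == 0` in the ONNX protobuf, so dynamic axes surface as 0 in the fallback shapes.

```python
import onnx
from onnx import TensorProto, helper

from modelopt.onnx.quantization.graph_utils import get_input_shapes

# Toy model: one fully static input, one with a symbolic batch dimension.
inputs = [
    helper.make_tensor_value_info("input0", TensorProto.FLOAT, [1, 3, 256, 256]),
    helper.make_tensor_value_info("input1", TensorProto.FLOAT, ["batch", 3, 128, 128]),
]
outputs = [
    helper.make_tensor_value_info("out0", TensorProto.FLOAT, [1, 3, 256, 256]),
    helper.make_tensor_value_info("out1", TensorProto.FLOAT, ["batch", 3, 128, 128]),
]
nodes = [
    helper.make_node("Relu", ["input0"], ["out0"]),
    helper.make_node("Relu", ["input1"], ["out1"]),
]
onnx.save(helper.make_model(helper.make_graph(nodes, "toy", inputs, outputs)), "toy.onnx")

# The dict form that quantize() now uses as the default calibration_shapes.
print(get_input_shapes("toy.onnx"))
# {'input0': [1, 3, 256, 256], 'input1': [0, 3, 128, 128]}
# The symbolic "batch" dim becomes 0, so for dynamic models an explicit string spec
# such as "input0:1x3x256x256,input1:1x3x128x128" may still be preferable.
```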