From 1546f752357ff33a017f2022a0c2f67f6ef410e7 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:17:12 -0500 Subject: [PATCH 1/3] Add check for input shapes vs calibration data shapes Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- docs/source/guides/8_autocast.rst | 1 + modelopt/onnx/autocast/referencerunner.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/source/guides/8_autocast.rst b/docs/source/guides/8_autocast.rst index 356d54c56..4ad39e969 100644 --- a/docs/source/guides/8_autocast.rst +++ b/docs/source/guides/8_autocast.rst @@ -110,6 +110,7 @@ Best Practices #. **Validate with Real Data**: - Provide representative input data using the ``calibration_data`` option for more accurate node classification. + - The input names and shapes in ``calibration_data`` should match the ones in the given ONNX model. #. **Control Reduction Depth**: - Use ``max_depth_of_reduction`` to limit the depth of reduction operations that can be converted to low precision. diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py index 6da8264bb..18d3d534d 100644 --- a/modelopt/onnx/autocast/referencerunner.py +++ b/modelopt/onnx/autocast/referencerunner.py @@ -44,6 +44,10 @@ def __init__( """Initialize with ONNX model path.""" self.model = model self.input_names = [input.name for input in self.model.graph.input] + self.input_shapes = { + input.name: [s.dim_value for s in input.type.tensor_type.shape.dim] + for input in self.model.graph.input + } self.providers = self._prepare_ep_list_with_trt_plugin_path(providers, trt_plugins) def _prepare_ep_list_with_trt_plugin_path(self, providers, trt_plugins): @@ -69,12 +73,18 @@ def _load_inputs_from_npz(self, input_data_path): return [np.load(input_data_path)] def _validate_inputs(self, data_loader): - """Validate that input names match the model.""" + """Validate that input names and shapes match the model.""" if isinstance(data_loader, list) and ( isinstance(data_loader[0], (dict, np.lib.npyio.NpzFile)) ): if sorted(self.input_names) != sorted(data_loader[0].keys()): raise ValueError("Input names from ONNX model do not match provided input names.") + for inp_name, inp_shape in data_loader[0].items(): + if self.input_shapes[inp_name] != inp_shape.shape: + raise ValueError( + f"Input shape from '{inp_name}' does not match provided input shape: " + f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}." + ) else: raise ValueError("Invalid input file.") From 69cc398329df3f892e857c1857d808646a6e396c Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:27:30 -0500 Subject: [PATCH 2/3] Add check for input shapes vs calibration data shapes Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/referencerunner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py index 18d3d534d..a5b682bd3 100644 --- a/modelopt/onnx/autocast/referencerunner.py +++ b/modelopt/onnx/autocast/referencerunner.py @@ -83,7 +83,8 @@ def _validate_inputs(self, data_loader): if self.input_shapes[inp_name] != inp_shape.shape: raise ValueError( f"Input shape from '{inp_name}' does not match provided input shape: " - f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}." + f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. " + f"Please make sure that your calibration data matches the ONNX input shapes." ) else: raise ValueError("Invalid input file.") From 0e8f3d3a941ad091ba2fbb53deee37f760660f68 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:06:09 -0500 Subject: [PATCH 3/3] Fix test: tuple to list Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/referencerunner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py index a5b682bd3..8dc91ff08 100644 --- a/modelopt/onnx/autocast/referencerunner.py +++ b/modelopt/onnx/autocast/referencerunner.py @@ -80,7 +80,7 @@ def _validate_inputs(self, data_loader): if sorted(self.input_names) != sorted(data_loader[0].keys()): raise ValueError("Input names from ONNX model do not match provided input names.") for inp_name, inp_shape in data_loader[0].items(): - if self.input_shapes[inp_name] != inp_shape.shape: + if self.input_shapes[inp_name] != list(inp_shape.shape): raise ValueError( f"Input shape from '{inp_name}' does not match provided input shape: " f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "