def _validate_inputs(self, data_loader):
    """Validate that input names and shapes match the model.

    Only the first sample in ``data_loader`` is inspected.

    Args:
        data_loader: List of samples, where each sample is a ``dict`` or
            ``np.lib.npyio.NpzFile`` mapping input names to arrays.

    Raises:
        ValueError: If ``data_loader`` is not a non-empty list of mappings, if the
            provided input names differ from ``self.input_names``, or if an array's
            rank or any fixed dimension differs from ``self.input_shapes``.
    """
    if not (
        isinstance(data_loader, list)
        and data_loader  # An empty list must hit the ValueError below, not an IndexError.
        and isinstance(data_loader[0], (dict, np.lib.npyio.NpzFile))
    ):
        raise ValueError("Invalid input file.")

    sample = data_loader[0]
    if sorted(self.input_names) != sorted(sample.keys()):
        raise ValueError("Input names from ONNX model do not match provided input names.")

    for inp_name, inp_data in sample.items():
        model_shape = list(self.input_shapes[inp_name])
        data_shape = list(inp_data.shape)
        # The model shapes are collected from ``dim_value``, which is 0 when a
        # dimension is symbolic/unset (dynamic). Such dimensions accept any size,
        # so only the rank and the fixed (> 0) dimensions are compared.
        shapes_match = len(model_shape) == len(data_shape) and all(
            expected <= 0 or expected == actual
            for expected, actual in zip(model_shape, data_shape)
        )
        if not shapes_match:
            raise ValueError(
                f"Input shape from '{inp_name}' does not match provided input shape: "
                f"{model_shape} vs {data_shape}. "
                f"Please make sure that your calibration data matches the ONNX input shapes."
            )