Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/guides/8_autocast.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Best Practices
#. **Validate with Real Data**:

- Provide representative input data using the ``calibration_data`` option for more accurate node classification.
- The input names and shapes in ``calibration_data`` should match the ones in the given ONNX model.

#. **Control Reduction Depth**:
- Use ``max_depth_of_reduction`` to limit the depth of reduction operations that can be converted to low precision.
Expand Down
13 changes: 12 additions & 1 deletion modelopt/onnx/autocast/referencerunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def __init__(
"""Initialize with ONNX model path."""
self.model = model
self.input_names = [input.name for input in self.model.graph.input]
self.input_shapes = {
input.name: [s.dim_value for s in input.type.tensor_type.shape.dim]
for input in self.model.graph.input
}
self.providers = self._prepare_ep_list_with_trt_plugin_path(providers, trt_plugins)

def _prepare_ep_list_with_trt_plugin_path(self, providers, trt_plugins):
Expand All @@ -69,12 +73,19 @@ def _load_inputs_from_npz(self, input_data_path):
return [np.load(input_data_path)]

def _validate_inputs(self, data_loader):
"""Validate that input names match the model."""
"""Validate that input names and shapes match the model."""
if isinstance(data_loader, list) and (
isinstance(data_loader[0], (dict, np.lib.npyio.NpzFile))
):
if sorted(self.input_names) != sorted(data_loader[0].keys()):
raise ValueError("Input names from ONNX model do not match provided input names.")
for inp_name, inp_shape in data_loader[0].items():
if self.input_shapes[inp_name] != inp_shape.shape:
raise ValueError(
f"Input shape from '{inp_name}' does not match provided input shape: "
f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "
f"Please make sure that your calibration data matches the ONNX input shapes."
)
else:
raise ValueError("Invalid input file.")

Expand Down
Loading