diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 5de600c0ec7..a20559c8cf6 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Callable + import torch from executorch import exir @@ -29,7 +32,13 @@ from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]): +@dataclass +class ModelInputSpec: + shape: tuple[int, ...] + dtype: torch.dtype = torch.float32 + + +def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]): quantizer = NeutronQuantizer() m = prepare_pt2e(model, quantizer) @@ -40,35 +49,52 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]): return m -def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]): - # TODO: Replace with something more robust. - return ( - (torch.randn(input_shapes),) - if type(input_shapes) is tuple - else tuple(torch.randn(input_shape) for input_shape in input_shapes) - ) +def get_random_calibration_inputs( + input_spec: tuple[ModelInputSpec, ...] +) -> list[tuple[torch.Tensor, ...]]: + return [ + tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec]) + for _ in range(4) + ] + + +def to_model_input_spec( + input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]] +) -> tuple[ModelInputSpec, ...]: + + if isinstance(input_spec, tuple) and all( + isinstance(spec, ModelInputSpec) for spec in input_spec + ): + return input_spec + + elif isinstance(input_spec, tuple) and all( + isinstance(spec, int) for spec in input_spec + ): + return (ModelInputSpec(input_spec),) + + elif isinstance(input_spec, list) and all( + isinstance(input_shape, tuple) for input_shape in input_spec + ): + return tuple([ModelInputSpec(spec) for spec in input_spec]) + else: + raise TypeError(f"Unsupported type {type(input_spec)}") def to_quantized_edge_program( model: torch.nn.Module, - input_shapes: tuple[int, ...] | list[tuple[int, ...]], + input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], operators_not_to_delegate: list[str] = None, + get_calibration_inputs_fn: Callable[ + [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]] + ] = get_random_calibration_inputs, target="imxrt700", neutron_converter_flavor="SDK_25_03", remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 ) -> EdgeProgramManager: - if isinstance(input_shapes, list): - assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), ( - "For multiple inputs, provide" " list[tuple[int]]." - ) + calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) - calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)] - example_input = ( - (torch.ones(input_shapes),) - if type(input_shapes) is tuple - else tuple(torch.ones(input_shape) for input_shape in input_shapes) - ) + example_input = calibration_inputs[0] exir_program_aten = torch.export.export_for_training( model, example_input, strict=True @@ -104,9 +130,10 @@ def to_quantized_edge_program( def to_quantized_executorch_program( - model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]] + model: torch.nn.Module, + input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], ) -> ExecutorchProgramManager: - edge_program_manager = to_quantized_edge_program(model, input_shapes) + edge_program_manager = to_quantized_edge_program(model, input_spec) return edge_program_manager.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) @@ -114,17 +141,11 @@ def to_quantized_executorch_program( def to_edge_program( - model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]] + model: nn.Module, + input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], ) -> EdgeProgramManager: - if isinstance(input_shapes, list): - assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), ( - "For multiple inputs, provide" " list[tuple[int]]." - ) + calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec)) - example_input = ( - (torch.ones(input_shapes),) - if type(input_shapes) is tuple - else tuple(torch.ones(input_shape) for input_shape in input_shapes) - ) + example_input = calibration_inputs[0] exir_program = torch.export.export(model, example_input) return exir.to_edge(exir_program)