backends/nxp/tests/executorch_pipeline.py (83 changes: 52 additions & 31 deletions)
@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
+from typing import Callable
+
 import torch
 
 from executorch import exir
@@ -29,7 +32,13 @@
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
-def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
+@dataclass
+class ModelInputSpec:
+    shape: tuple[int, ...]
+    dtype: torch.dtype = torch.float32
+
+
+def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
     quantizer = NeutronQuantizer()
 
     m = prepare_pt2e(model, quantizer)
@@ -40,35 +49,52 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     return m
 
 
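Note: the middle of `_quantize_model` is folded above. For orientation, here is a minimal sketch of the torchao PT2E flow the function wraps; the calibration loop is an assumption, since those lines are collapsed in this diff, and the import path for `NeutronQuantizer` is repeated from the top of the file only to keep the sketch self-contained.

```python
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Path assumed from the repo layout; the real file imports this at the top.
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer


def quantize_model_sketch(model, calibration_inputs):
    # model: an exported GraphModule (see export_for_training further below).
    quantizer = NeutronQuantizer()
    m = prepare_pt2e(model, quantizer)  # wrap quantizable ops with observers
    for sample in calibration_inputs:   # assumed: run each calibration batch
        m(*sample)                      # observers record activation ranges
    return convert_pt2e(m)              # fold observed ranges into quantized ops
```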
-def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
-    # TODO: Replace with something more robust.
-    return (
-        (torch.randn(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.randn(input_shape) for input_shape in input_shapes)
-    )
+def get_random_calibration_inputs(
+    input_spec: tuple[ModelInputSpec, ...]
+) -> list[tuple[torch.Tensor, ...]]:
+    return [
+        tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
+        for _ in range(4)
+    ]
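A quick usage sketch of the new helper (shapes are made up). One caveat worth noting: `torch.randn` only produces floating-point tensors, so a `ModelInputSpec` with an integer dtype would need a different generator.

```python
import torch

# Hypothetical two-input model spec: a float32 image plus a float16 tensor.
specs = (
    ModelInputSpec((1, 3, 32, 32)),                # dtype defaults to float32
    ModelInputSpec((1, 10), dtype=torch.float16),  # non-default dtype
)

batches = get_random_calibration_inputs(specs)

assert len(batches) == 4                      # four random calibration batches
assert batches[0][0].shape == (1, 3, 32, 32)  # one tensor per ModelInputSpec
assert batches[0][1].dtype == torch.float16
```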


+def to_model_input_spec(
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
+) -> tuple[ModelInputSpec, ...]:
+
+    if isinstance(input_spec, tuple) and all(
+        isinstance(spec, ModelInputSpec) for spec in input_spec
+    ):
+        return input_spec
+
+    elif isinstance(input_spec, tuple) and all(
+        isinstance(spec, int) for spec in input_spec
+    ):
+        return (ModelInputSpec(input_spec),)
+
+    elif isinstance(input_spec, list) and all(
+        isinstance(input_shape, tuple) for input_shape in input_spec
+    ):
+        return tuple([ModelInputSpec(spec) for spec in input_spec])
+    else:
+        raise TypeError(f"Unsupported type {type(input_spec)}")
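As a quick reference, the three accepted spellings and how they normalize (shapes are hypothetical):

```python
import torch

# A bare shape tuple is treated as a single input:
to_model_input_spec((1, 3, 224, 224))
# -> (ModelInputSpec(shape=(1, 3, 224, 224), dtype=torch.float32),)

# A list of shape tuples means one input per shape:
to_model_input_spec([(1, 10), (1, 20)])
# -> (ModelInputSpec((1, 10)), ModelInputSpec((1, 20)))

# A tuple of ModelInputSpec passes through unchanged:
to_model_input_spec((ModelInputSpec((1, 10), dtype=torch.float16),))

# Anything else raises:
to_model_input_spec("not a spec")  # TypeError: Unsupported type <class 'str'>
```

One subtlety: an empty tuple satisfies the first `all(...)` check vacuously, so `()` passes through rather than raising.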


 def to_quantized_edge_program(
     model: torch.nn.Module,
-    input_shapes: tuple[int, ...] | list[tuple[int, ...]],
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
     operators_not_to_delegate: list[str] = None,
+    get_calibration_inputs_fn: Callable[
+        [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
+    ] = get_random_calibration_inputs,
     target="imxrt700",
     neutron_converter_flavor="SDK_25_03",
     remove_quant_io_ops=False,
     custom_delegation_options=CustomDelegationOptions(),  # noqa B008
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
 
-    calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
 
     exir_program_aten = torch.export.export_for_training(
         model, example_input, strict=True
@@ -104,27 +130,22 @@ def to_quantized_edge_program(


 def to_quantized_executorch_program(
-    model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
+    model: torch.nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> ExecutorchProgramManager:
-    edge_program_manager = to_quantized_edge_program(model, input_shapes)
+    edge_program_manager = to_quantized_edge_program(model, input_spec)
 
     return edge_program_manager.to_executorch(
         config=ExecutorchBackendConfig(extract_delegate_segments=False)
     )
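For context, a sketch of how the resulting `ExecutorchProgramManager` is typically serialized, as in the standard ExecuTorch export examples. The tiny model and output path are made up, and this assumes the Neutron toolchain for the default imxrt700 target is available.

```python
import torch


class TinyModel(torch.nn.Module):  # stand-in model, not part of this PR
    def forward(self, x):
        return torch.relu(x)


prog = to_quantized_executorch_program(TinyModel().eval(), (1, 16))

with open("tiny_model.pte", "wb") as f:  # hypothetical output path
    f.write(prog.buffer)                 # .buffer holds the serialized program
```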


 def to_edge_program(
-    model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
+    model: nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
 
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
     exir_program = torch.export.export(model, example_input)
     return exir.to_edge(exir_program)
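End to end, the main new extension point is `get_calibration_inputs_fn`, which lets tests calibrate on chosen data instead of random tensors. A sketch under assumptions: the model and calibration function below are illustrative, and the Neutron converter for the default imxrt700 target is assumed to be installed.

```python
import torch


class SmallConv(torch.nn.Module):  # illustrative model, not from this PR
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


def fixed_calibration_inputs(spec):
    # Honors the normalized spec but returns deterministic tensors; a
    # stand-in for feeding real dataset samples into calibration.
    return [
        tuple(torch.full(s.shape, 0.5, dtype=s.dtype) for s in spec)
        for _ in range(8)
    ]


edge_program = to_quantized_edge_program(
    SmallConv().eval(),
    (1, 3, 32, 32),  # bare-shape spelling; normalized via to_model_input_spec
    get_calibration_inputs_fn=fixed_calibration_inputs,
)
```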