Skip to content

Commit 7787b6d

Browse files
Add tests for Mean Dim operator
1 parent e1a0db8 commit 7787b6d

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

backends/nxp/tests/executors.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,31 @@ def preprocess(self, data: np.ndarray | dict[int, numpy.ndarray]):
159159

160160

161161
class ToChannelFirstPreprocess(TFLiteIOPreprocess):
162+
def __init__(self, dim_0_reduced: bool | dict[int, bool] = False):
163+
self.dim_0_reduced = dim_0_reduced
164+
162165
def preprocess(self, data: np.ndarray | dict[int, np.ndarray]):
163-
def get_channel_first_permutation(tensor):
164-
return create_channels_last_to_channels_first_permutation(len(tensor.shape))
166+
def get_channel_first_permutation(tensor, dim_0_reduced):
167+
tensor_rank = len(tensor.shape)
168+
perm = create_channels_last_to_channels_first_permutation(tensor_rank)
169+
if dim_0_reduced and tensor_rank > 1:
170+
perm[0], perm[1] = perm[1], perm[0]
171+
return perm
172+
173+
transpose_fn = lambda x, rank: np.transpose(x, get_channel_first_permutation(x, rank))
174+
if isinstance(data, np.ndarray) and isinstance(self.dim_0_reduced, bool):
175+
preprocessed_data = transpose_fn(data, self.dim_0_reduced)
176+
177+
elif isinstance(data, dict) and isinstance(self.dim_0_reduced, bool):
178+
preprocessed_data = {k: transpose_fn(v, self.dim_0_reduced) for k, v in data.items()}
179+
180+
elif isinstance(data, dict) and isinstance(self.dim_0_reduced, dict):
181+
preprocessed_data = {k: transpose_fn(v, self.dim_0_reduced[k]) for k, v in data.items()}
165182

166-
transpose_fn = lambda x: np.transpose(x, get_channel_first_permutation(x))
167-
if isinstance(data, np.ndarray):
168-
preprocessed_data = transpose_fn(data)
169183
else:
170-
preprocessed_data = {k: transpose_fn(v) for k, v in data.items()}
184+
raise ValueError("Invalid combination of inputs. Data can be either np.ndarray or dict. If original number "
185+
"of dimension is used, it can be only int for np.ndarray data or dict of ints for dict "
186+
"data with same keys.")
171187
return preprocessed_data
172188

173189

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
from torch.export import ExportedProgram
5+
6+
from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
7+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
8+
from executorch.backends.nxp.tests.executors import convert_run_compare, ToChannelFirstPreprocess, \
9+
ToChannelLastPreprocess
10+
from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule
11+
12+
13+
@pytest.fixture(autouse=True)
14+
def reseed_model_per_test_run():
15+
torch.manual_seed(23)
16+
np.random.seed(23)
17+
18+
19+
@pytest.mark.parametrize("input_shape, dim", [
20+
pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."),
21+
])
22+
def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True):
23+
model = MeanDimConvModule(dim, keeepdim)
24+
25+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
26+
27+
# Run conversion
28+
_ = to_quantized_edge_program(model, input_shape)
29+
30+
# Capture generated model
31+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
32+
33+
# Capture converted program
34+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
35+
36+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
37+
38+
convert_run_compare(exported_program, tflite_input_preprocess=ToChannelLastPreprocess(), input_data=input_data,
39+
tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model)
40+
41+
42+
@pytest.mark.parametrize("input_shape, dim", [
43+
pytest.param((1, 32), 0, id="Dim 0."),
44+
pytest.param((1, 32), 1, id="Dim 1."),
45+
])
46+
@pytest.mark.parametrize("keeepdim", [
47+
pytest.param(False, id="Don't keep dim."),
48+
pytest.param(True, id="Keep dim."),
49+
])
50+
def test_mean_dim_linear_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
51+
model = MeanDimLinearModule(dim, keeepdim)
52+
53+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
54+
55+
# Run conversion
56+
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
57+
nodes = list(edge_program.graph.nodes)
58+
59+
# Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated
60+
assert nodes[6].target.__name__ == 'aten.mean.dim'
61+
62+
# Capture generated model
63+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
64+
65+
# Capture converted program
66+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
67+
68+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
69+
70+
convert_run_compare(exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data)
71+
72+
73+
@pytest.mark.parametrize("input_shape, dim", [
74+
pytest.param((1, 4, 8, 8), 0, id="Dim 0."),
75+
pytest.param((1, 4, 8, 8), 2, id="Dim 2."),
76+
pytest.param((1, 4, 8, 8), -1, id="Dim -1."),
77+
pytest.param((1, 4, 8, 8), -2, id="Dim -2."),
78+
pytest.param((1, 4, 8, 8), (0, 1), id="Dim 0, 1."),
79+
pytest.param((1, 4, 8, 8), (1, 3), id="Dim 1, 3."),
80+
pytest.param((1, 4, 8, 8), (-1, -3), id="Dim -1, -3."),
81+
])
82+
@pytest.mark.parametrize("keeepdim", [
83+
pytest.param(False, id="Don't keep dim."),
84+
pytest.param(True, id="Keep dim."),
85+
])
86+
def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
87+
model = MeanDimConvModule(dim, keeepdim)
88+
89+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
90+
91+
# Run conversion
92+
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
93+
nodes = list(edge_program.graph.nodes)
94+
95+
# Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated
96+
assert nodes[6].target.__name__ == 'aten.mean.dim'
97+
98+
# Capture generated model
99+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
100+
101+
# Capture converted program
102+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
103+
104+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
105+
106+
convert_run_compare(exported_program, tflite_input_preprocess=ToChannelLastPreprocess(), input_data=input_data,
107+
tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model)

backends/nxp/tests/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,27 @@ def __init__(self):
249249
@staticmethod
250250
def forward(x):
251251
return x + x
252+
253+
254+
class MeanDimLinearModule(torch.nn.Module):
255+
def __init__(self, dim, keepdim):
256+
super().__init__()
257+
self.dim = dim
258+
self.keepdim = keepdim
259+
self.linear = torch.nn.Linear(32, 16)
260+
261+
def forward(self, x):
262+
x = self.linear(x)
263+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim)
264+
265+
266+
class MeanDimConvModule(torch.nn.Module):
267+
def __init__(self, dim, keepdim):
268+
super().__init__()
269+
self.conv = Conv2dModule(stride=1, padding=1)
270+
self.dim = dim
271+
self.keepdim = keepdim
272+
273+
def forward(self, x):
274+
x = self.conv(x)
275+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim)

0 commit comments

Comments
 (0)