Skip to content

Commit 6bd2c2d

Browse files
Add test for Adaptive Average Pool operator
1 parent a9b1f19 commit 6bd2c2d

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
from torch.export import ExportedProgram
5+
6+
from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
7+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
8+
from executorch.backends.nxp.tests.executors import convert_run_compare, ToChannelLastPreprocess, \
9+
ToChannelFirstPreprocess
10+
from executorch.backends.nxp.tests.models import AdaptiveAvgPool2dConvModule, AdaptiveAvgPool2dConvMeanDimModule
11+
12+
13+
@pytest.fixture(autouse=True)
14+
def reseed_model_per_test_run():
15+
torch.manual_seed(23)
16+
np.random.seed(23)
17+
18+
19+
@pytest.mark.parametrize("input_shape, output_size", [
20+
pytest.param((1, 4, 16, 16), (4, 4), id="Pooling with equal height and width kernel."),
21+
pytest.param((1, 4, 16, 16), (8, 8), id="Pooling with equal height and width kernel."),
22+
pytest.param((1, 4, 16, 16), (4, 8), id="Pooling with height > width kernel."),
23+
pytest.param((1, 4, 16, 22), (4, 11), id="Pooling with height > width kernel."),
24+
pytest.param((1, 4, 32, 32), (16, 4), id="Pooling with height < width kernel."),
25+
pytest.param((1, 4, 32, 16), (16, 4), id="Pooling with height < width kernel."),
26+
])
27+
def test_adaptive_avg_pool_2d_delegated_quant_conversion(mocker, input_shape, output_size):
28+
model = AdaptiveAvgPool2dConvModule(output_size)
29+
30+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
31+
32+
# Run conversion
33+
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
34+
nodes = [str(node) for node in edge_program.graph.nodes]
35+
36+
# Input size is a multiple of output size, can be converted to AveragePool, node is delegated
37+
assert 'aten__adaptive_avg_pool2d_default' not in nodes
38+
39+
# Capture generated model
40+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
41+
42+
# Capture converted program
43+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
44+
45+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
46+
47+
convert_run_compare(exported_program, tflite_input_preprocess=ToChannelLastPreprocess(), tfl_model=tflite_flatbuffers_model,
48+
tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1)
49+
50+
51+
@pytest.mark.parametrize("input_shape, output_size", [
52+
pytest.param((1, 4, 16, 16), (6, 6), id="Pooling with equal height and width kernel."),
53+
pytest.param((1, 4, 16, 16), (4, 7), id="Pooling with height > width kernel."),
54+
pytest.param((1, 4, 16, 22), (4, 10), id="Pooling with height > width kernel."),
55+
pytest.param((1, 4, 32, 32), (14, 7), id="Pooling with height < width kernel."),
56+
pytest.param((1, 4, 32, 16), (15, 5), id="Pooling with height < width kernel."),
57+
])
58+
def test_adaptive_avg_pool_2d_non_delegated_quant_conversion(mocker, input_shape, output_size):
59+
model = AdaptiveAvgPool2dConvModule(output_size)
60+
61+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
62+
63+
# Run conversion
64+
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
65+
nodes = list(edge_program.graph.nodes)
66+
67+
# Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated
68+
assert str(nodes[6]) == 'aten__adaptive_avg_pool2d_default'
69+
70+
# Capture generated model
71+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
72+
73+
# Capture converted program
74+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
75+
76+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
77+
78+
convert_run_compare(exported_program, tflite_input_preprocess=ToChannelLastPreprocess(), tfl_model=tflite_flatbuffers_model,
79+
tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1)
80+
81+
82+
def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker):
83+
input_shape = (1, 4, 16, 16)
84+
model = AdaptiveAvgPool2dConvMeanDimModule()
85+
86+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
87+
88+
# Run conversion
89+
_ = to_quantized_edge_program(model, input_shape)
90+
91+
# Capture generated model
92+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
93+
94+
# Capture converted program
95+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
96+
97+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
98+
99+
convert_run_compare(exported_program, tflite_input_preprocess=ToChannelLastPreprocess(), tfl_model=tflite_flatbuffers_model,
100+
tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data)

backends/nxp/tests/models.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,41 @@ def forward(self, x):
191191
return self.avg_pool(x)
192192

193193

194+
class AdaptiveAvgPool2dModule(torch.nn.Module):
195+
def __init__(self, output_size):
196+
super().__init__()
197+
198+
self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=output_size)
199+
200+
def forward(self, x):
201+
return self.adaptive_avg_pool(x)
202+
203+
204+
class AdaptiveAvgPool2dConvModule(torch.nn.Module):
205+
def __init__(self, output_size):
206+
super().__init__()
207+
208+
self.conv = Conv2dModule(padding=1)
209+
self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=output_size)
210+
211+
def forward(self, x):
212+
x = self.conv(x)
213+
return self.adaptive_avg_pool(x)
214+
215+
216+
class AdaptiveAvgPool2dConvMeanDimModule(torch.nn.Module):
217+
def __init__(self):
218+
super().__init__()
219+
220+
self.conv = Conv2dModule()
221+
self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
222+
223+
def forward(self, x):
224+
x = self.conv(x)
225+
x = self.adaptive_avg_pool(x)
226+
return x
227+
228+
194229
class ReLUModule(torch.nn.Module):
195230
def __init__(self):
196231
super().__init__()

0 commit comments

Comments
 (0)