
Commit 88f6c22

Extend tests for Linear, Addmm, Mm converters

1 parent 6660b99 commit 88f6c22

File tree

4 files changed: 205 additions, 49 deletions

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import kgb
import numpy as np
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
    graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import AddmmModule, LinearModule
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class TestAddmmConversion(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        torch.manual_seed(23)
        np.random.seed(42)

    def test_addmm_conversion(self):
        with kgb.spy_on(
            EdgeProgramToIRConverter.convert_program, call_original=True
        ) as converter_spy:
            input_shape = (1, 32)
            model = AddmmModule(input_shape[1])

            edge_program = to_quantized_edge_program(
                model, input_shape
            ).exported_program()

            # Make sure that all nodes were delegated.
            assert not graph_contains_any_of_ops(
                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
            )
            assert any(
                "lowered_module" in node.name for node in edge_program.graph.nodes
            )

            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
                np.int8
            )
            convert_run_compare(
                exported_program,
                input_data,
                tfl_model=tflite_flatbuffers_model,
            )

    def test_linear_conversion__with_bias(self):
        with kgb.spy_on(
            EdgeProgramToIRConverter.convert_program, call_original=True
        ) as converter_spy:
            input_shape = (10, 32)
            model = LinearModule(bias=True)

            edge_program = to_quantized_edge_program(
                model, input_shape
            ).exported_program()

            # Make sure that all nodes were delegated.
            assert not graph_contains_any_of_ops(
                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
            )
            assert any(
                "lowered_module" in node.name for node in edge_program.graph.nodes
            )

            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
                np.int8
            )
            convert_run_compare(
                exported_program,
                input_data,
                tfl_model=tflite_flatbuffers_model,
            )

backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py

Lines changed: 0 additions & 49 deletions
This file was deleted.
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import kgb
import numpy as np
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
    graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import LinearModule, MmModule
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class TestMmConversion(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        torch.manual_seed(23)
        np.random.seed(42)

    def test_mm_conversion(self):
        with kgb.spy_on(
            EdgeProgramToIRConverter.convert_program, call_original=True
        ) as converter_spy:
            input_shape = (1, 32)
            model = MmModule(input_shape[1])

            edge_program = to_quantized_edge_program(
                model, input_shape
            ).exported_program()

            # Make sure that all nodes were delegated.
            assert not graph_contains_any_of_ops(
                graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default]
            )
            assert any(
                "lowered_module" in node.name for node in edge_program.graph.nodes
            )

            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
                np.int8
            )
            convert_run_compare(
                exported_program,
                input_data,
                tfl_model=tflite_flatbuffers_model,
            )

    def test_linear_conversion__without_bias(self):
        with kgb.spy_on(
            EdgeProgramToIRConverter.convert_program, call_original=True
        ) as converter_spy:
            input_shape = (10, 32)
            model = LinearModule(bias=False)

            edge_program = to_quantized_edge_program(
                model, input_shape
            ).exported_program()

            # Make sure that all nodes were delegated.
            assert not graph_contains_any_of_ops(
                graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default]
            )
            assert any(
                "lowered_module" in node.name for node in edge_program.graph.nodes
            )

            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
                np.int8
            )
            convert_run_compare(
                exported_program,
                input_data,
                tfl_model=tflite_flatbuffers_model,
            )
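
Note: both new test files generate quantized-model inputs with the same expression, float32 noise scaled by 50 and then cast to int8. A minimal standalone sketch of what that expression produces, kept outside the commit itself:

import numpy as np

np.random.seed(42)
input_shape = (1, 32)

# np.random.random() yields floats in [0, 1); after the float32 cast and the
# scaling by 50, the int8 cast truncates toward zero, so the tests feed the
# quantized model integer values in the range 0..49 stored as int8.
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
print(input_data.dtype, input_data.min(), input_data.max())  # int8, values within [0, 49]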

backends/nxp/tests/models.py

Lines changed: 27 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 from typing import Callable, Collection, Union
 
 import torch
@@ -169,6 +170,32 @@ def forward(self, x):
         return self.linear(x)
 
 
+class AddmmModule(torch.nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels))
+        self.bias = torch.nn.Parameter(torch.empty(in_channels))
+        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+        bound = 1 / math.sqrt(fan_in)
+        torch.nn.init.uniform_(self.bias, -bound, bound)
+        self.eval()
+
+    def forward(self, x):
+        return torch.addmm(self.bias, x, self.weight)
+
+
+class MmModule(torch.nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels))
+        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.eval()
+
+    def forward(self, x):
+        return torch.mm(x, self.weight)
+
+
 class LinearSoftmaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
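
For reference, the two new helper modules call the aten ops directly rather than going through torch.nn.Linear: torch.addmm(bias, x, W) computes bias + x @ W and torch.mm(x, W) computes x @ W, with the weight used untransposed (unlike nn.Linear), while the parameter initialization follows the same Kaiming-uniform scheme nn.Linear uses by default. A minimal standalone sketch of those semantics, not part of the commit:

import torch

torch.manual_seed(23)
x = torch.randn(10, 32)
w = torch.randn(32, 32)
b = torch.randn(32)

# AddmmModule.forward: torch.addmm(bias, mat1, mat2) == bias + mat1 @ mat2
assert torch.allclose(torch.addmm(b, x, w), x @ w + b, atol=1e-5)

# MmModule.forward: torch.mm(mat1, mat2) == mat1 @ mat2
assert torch.allclose(torch.mm(x, w), x @ w, atol=1e-5)

# torch.nn.Linear would instead compute x @ w.T + b.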
