
Commit 795ab37

Extend tests for Linear, Addmm, Mm converters
1 parent ead74f2 commit 795ab37

File tree

4 files changed: +204 -49 lines changed

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import kgb
+import numpy as np
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    graph_contains_any_of_ops,
+)
+from executorch.backends.nxp.tests.models import AddmmModule, LinearModule
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
+
+
+class TestAddmmConversion(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(42)
+
+    def test_addmm_conversion(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (1, 32)
+            model = AddmmModule(input_shape[1])
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
+
+    def test_linear_conversion__with_bias(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (10, 32)
+            model = LinearModule(bias=True)
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
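
For reference, torch.addmm(bias, x, weight) computes bias + x @ weight, i.e. the same kind of computation as a fully connected (Linear) layer with bias, which is what the AddmmModule used above exercises. A minimal standalone check (not part of the commit; shapes mirror the (1, 32) input used in the test):

    import torch

    torch.manual_seed(23)
    x = torch.randn(1, 32)        # same (1, 32) input shape as test_addmm_conversion
    weight = torch.randn(32, 32)  # AddmmModule uses a square (in_channels, in_channels) weight
    bias = torch.randn(32)

    # torch.addmm(input, mat1, mat2) returns input + mat1 @ mat2
    assert torch.allclose(torch.addmm(bias, x, weight), bias + x @ weight)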

backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py

Lines changed: 0 additions & 49 deletions
This file was deleted.
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import kgb
+import numpy as np
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    graph_contains_any_of_ops,
+)
+from executorch.backends.nxp.tests.models import LinearModule, MmModule
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
+
+
+class TestMmConversion(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(42)
+
+    def test_mm_conversion(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (1, 32)
+            model = MmModule(input_shape[1])
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
+
+    def test_linear_conversion__without_bias(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (10, 32)
+            model = LinearModule(bias=False)
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
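
Both new test classes are plain unittest.TestCase suites, so they can also be run locally via standard unittest discovery. A sketch, assuming the new modules live next to the deleted test_linear_converter.py (the exact new file names are not shown in this view):

    import unittest

    # Discover the converter tests in the directory taken from the deleted file's path;
    # the pattern is an assumption about the new file names.
    suite = unittest.defaultTestLoader.discover(
        "backends/nxp/tests/ir/converter/node_converter",
        pattern="test_*_converter.py",
    )
    unittest.TextTestRunner(verbosity=2).run(suite)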

backends/nxp/tests/models.py

Lines changed: 26 additions & 0 deletions
@@ -7,6 +7,7 @@
 from typing import Callable, Collection, Union

 import torch
+import math


 class Conv1dModule(torch.nn.Module):
@@ -168,6 +169,31 @@ def __init__(self, bias: bool):
     def forward(self, x):
         return self.linear(x)

+class AddmmModule(torch.nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels))
+        self.bias = torch.nn.Parameter(torch.empty(in_channels))
+        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+        bound = 1 / math.sqrt(fan_in)
+        torch.nn.init.uniform_(self.bias, -bound, bound)
+        self.eval()
+
+    def forward(self, x):
+        return torch.addmm(self.bias, x, self.weight)
+
+
+class MmModule(torch.nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels))
+        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.eval()
+
+    def forward(self, x):
+        return torch.mm(x, self.weight)
+

 class LinearSoftmaxModule(torch.nn.Module):
     def __init__(self):