Skip to content

Commit af80568

Browse files
committed
Arm backend: Add LSTM test for int16x8
Signed-off-by: Saoirse Stewart <[email protected]>
1 parent 67d8cef commit af80568

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

backends/arm/test/models/test_lstm_arm.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55

66
from typing import Tuple
77

8+
import pytest
89
import torch
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
get_symmetric_a16w8_quantization_config,
12+
TOSAQuantizer,
13+
)
914

10-
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test import common, conftest
1116
from executorch.backends.arm.test.tester.test_pipeline import (
1217
EthosU55PipelineINT,
1318
EthosU85PipelineINT,
@@ -16,6 +21,9 @@
1621
VgfPipeline,
1722
)
1823

24+
from executorch.backends.arm.tosa import TosaSpecification
25+
from executorch.backends.xnnpack.test.tester import Quantize
26+
1927
from torch.nn.quantizable.modules import rnn
2028

2129
input_t = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] # (h0, c0)
@@ -123,3 +131,71 @@ def test_lstm_vgf_FP():
123131
use_to_edge_transform_and_lower=True,
124132
)
125133
pipeline.run()
134+
135+
136+
def get_symmetric_a16w8_lstm_quantizer(per_channel_quantization=False):
137+
tosa_version = conftest.get_option("tosa_version")
138+
tosa_profiles = {
139+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
140+
}
141+
142+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
143+
quantizer.set_global(
144+
get_symmetric_a16w8_quantization_config(
145+
is_per_channel=per_channel_quantization, epsilon=2**-16
146+
)
147+
)
148+
149+
return Quantize(
150+
quantizer,
151+
get_symmetric_a16w8_quantization_config(
152+
is_per_channel=per_channel_quantization, epsilon=2**-16
153+
),
154+
)
155+
156+
157+
def test_lstm_16a8w_tosa_INT():
158+
"""Test LSTM model with 16A8W quantization (16-bit activations, 8-bit weights)"""
159+
160+
pipeline = TosaPipelineINT[input_t](
161+
TestLSTM.lstm,
162+
TestLSTM.model_example_inputs,
163+
aten_op=[],
164+
exir_op=[],
165+
per_channel_quantization=False,
166+
use_to_edge_transform_and_lower=True,
167+
tosa_extensions=["int16"],
168+
)
169+
170+
pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
171+
pipeline.run()
172+
173+
174+
@pytest.mark.xfail(
175+
reason="MLETORCH-1452: AssertionError: Output 0 does not match reference output."
176+
)
177+
@common.XfailIfNoCorstone300
178+
def test_lstm_16a8w_u55_INT():
179+
pipeline = EthosU55PipelineINT[input_t](
180+
TestLSTM.lstm,
181+
TestLSTM.model_example_inputs,
182+
aten_ops=[],
183+
exir_ops=[],
184+
use_to_edge_transform_and_lower=True,
185+
)
186+
187+
pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
188+
pipeline.run()
189+
190+
191+
@common.XfailIfNoCorstone320
192+
def test_lstm_16a8w_u85_INT():
193+
pipeline = EthosU85PipelineINT[input_t](
194+
TestLSTM.lstm,
195+
TestLSTM.model_example_inputs,
196+
aten_ops=[],
197+
exir_ops=[],
198+
use_to_edge_transform_and_lower=True,
199+
)
200+
pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer())
201+
pipeline.run()

0 commit comments

Comments
 (0)