
Commit cd7d115

Merge branch 'pytorch:main' into main

2 parents: 52eeb46 + 98e4dd5

38 files changed: +432 -192 lines

backends/arm/test/conftest.py

Lines changed: 11 additions & 22 deletions

@@ -11,7 +11,6 @@
 import shutil
 import subprocess
 import sys
-from enum import auto, Enum
 from typing import Any
 
 import pytest
@@ -22,30 +21,24 @@
 """
 
 
-class arm_test_options(Enum):
-    quantize_io = auto()
-    corstone_fvp = auto()
-    fast_fvp = auto()
-
-
-_test_options: dict[arm_test_options, Any] = {}
-
 # ==== Pytest hooks ====
 
 
 def pytest_configure(config):
+    pytest._test_options = {}
+
     if config.option.arm_quantize_io:
         _load_libquantized_ops_aot_lib()
-        _test_options[arm_test_options.quantize_io] = True
+        pytest._test_options["quantize_io"] = True
     if config.option.arm_run_corstoneFVP:
         corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
         corstone320_exists = shutil.which("FVP_Corstone_SSE-320")
         if not (corstone300_exists and corstone320_exists):
             raise RuntimeError(
                 "Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed."
             )
-        _test_options[arm_test_options.corstone_fvp] = True
-    _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp
+        pytest._test_options["corstone_fvp"] = True
+    pytest._test_options["fast_fvp"] = config.option.fast_fvp
     logging.basicConfig(level=logging.INFO, stream=sys.stdout)
@@ -131,9 +124,7 @@ def expectedFailureOnFVP(test_item):
 # ==== End of Custom Pytest decorators =====
 
 
-def is_option_enabled(
-    option: str | arm_test_options, fail_if_not_enabled: bool = False
-) -> bool:
+def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
     """
     Returns whether an option is successfully enabled, i.e. if the flag was
     given to pytest and the necessary requirements are available.
@@ -144,10 +135,8 @@ def is_option_enabled(
     The optional parameter 'fail_if_not_enabled' makes the function raise
     a RuntimeError instead of returning False.
     """
-    if isinstance(option, str):
-        option = arm_test_options[option.lower()]
 
-    if option in _test_options and _test_options[option]:
+    if option in pytest._test_options and pytest._test_options[option]:
         return True
     else:
         if fail_if_not_enabled:
@@ -156,15 +145,15 @@ def is_option_enabled(
         return False
 
 
-def get_option(option: arm_test_options) -> Any | None:
+def get_option(option: str) -> Any | None:
     """
     Returns the value of an pytest option if it is set, otherwise None.
 
     Args:
-        option (arm_test_options): The option to check for.
+        option (str): The option to check for.
     """
-    if option in _test_options:
-        return _test_options[option]
+    if option in pytest._test_options:
+        return pytest._test_options[option]
     return None
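Note on the conftest.py change: test options are now keyed by plain strings and stored on the pytest module object, replacing the removed arm_test_options enum and the module-level _test_options dict. A minimal sketch of how a caller consumes the new API (the test function below is hypothetical, not part of this commit):

from executorch.backends.arm.test.conftest import get_option, is_option_enabled


def test_needs_fvp():
    # Raises RuntimeError if pytest was started without --arm_run_corstoneFVP.
    is_option_enabled("corstone_fvp", fail_if_not_enabled=True)

    # get_option returns the stored value, or None if the flag was never set.
    if get_option("fast_fvp"):
        ...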
backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 8 additions & 6 deletions

@@ -156,14 +156,11 @@
     ("two_dw_conv2d", two_dw_conv2d),
 ]
 
-testsuite_conv2d_u85 = [
+testsuite_conv2d_u85_xfails = [
     ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
     ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
     ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
     ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
-]
-
-testsuite_conv2d_u85_xfails = [
     ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
     ("two_dw_conv2d", two_dw_conv2d),
 ]
@@ -260,6 +257,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
     )  # Works
 
     @parameterized.expand(testsuite_conv2d, skip_on_empty=True)
+    @unittest.expectedFailure
     def test_dw_conv2d_u55_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
@@ -286,7 +284,7 @@ def test_dw_conv1d_u55_BI(
             model.get_inputs(),
         )
 
-    @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
+    @parameterized.expand(testsuite_conv1d[2:])
     def test_dw_conv_u85_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
@@ -298,8 +296,12 @@ def test_dw_conv_u85_BI(
             model.get_inputs(),
         )
 
+    testsuite_conv2d_u85_xfails.remove(
+        ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1)
+    )  # Works
+
     # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
-    @parameterized.expand(testsuite_conv2d_u85_xfails)
+    @parameterized.expand(testsuite_conv2d_u85_xfails + testsuite_conv1d[:2])
     @conftest.expectedFailureOnFVP
     def test_dw_conv_u85_BI_xfails(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
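A side note on the testsuite_conv2d_u85_xfails.remove(...) line added above: it takes effect because parameterized.expand consumes its list while the class body executes, so a mutation placed before the decorator changes which cases are generated. A standalone sketch of that mechanism (names are illustrative, not from this commit):

import unittest

from parameterized import parameterized

cases = [("works", 1), ("fails", 2)]
cases.remove(("fails", 2))  # Runs before expand() reads the list below.


class Demo(unittest.TestCase):
    # Only the ("works", 1) case is generated; expand() captured the list
    # contents after the remove() above had already run.
    @parameterized.expand(cases)
    def test_case(self, name, value):
        self.assertEqual(value, 1)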

backends/arm/test/ops/test_div.py

Lines changed: 2 additions & 28 deletions

@@ -183,21 +183,8 @@ def test_div_tosa_BI(
         test_data = (input_, other_)
         self._test_div_tosa_BI_pipeline(self.Div(), test_data)
 
-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u55_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u55_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u55_BI_xfails(
         self,
@@ -211,21 +198,8 @@ def test_div_u55_BI_xfails(
             self.Div(), common.get_u55_compile_spec(), test_data
         )
 
-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u85_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u85_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u85_BI_xfails(
         self,
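Both the U55 and U85 div suites now send every case through conftest.expectedFailureOnFVP instead of keeping a separate passing subset. The decorator's body is not shown in this diff; a plausible sketch, assuming it simply defers to unittest.expectedFailure when the Corstone FVP option is active:

import unittest

from executorch.backends.arm.test.conftest import is_option_enabled


def expectedFailureOnFVP(test_item):
    # Hypothetical reconstruction: mark the test as an expected failure only
    # when tests actually run on the Corstone FVP; otherwise leave it as-is.
    if is_option_enabled("corstone_fvp"):
        return unittest.expectedFailure(test_item)
    return test_item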

backends/arm/test/ops/test_mul.py

Lines changed: 6 additions & 1 deletion

@@ -152,7 +152,9 @@ def test_mul_tosa_BI(
         test_data = (input_, other_)
         self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)
 
+    # Numerical issues on FVP, MLETORCH-521
     @parameterized.expand(test_data_sute)
+    @conftest.expectedFailureOnFVP
     def test_mul_u55_BI(
         self,
         test_name: str,
@@ -164,7 +166,10 @@ def test_mul_u55_BI(
             common.get_u55_compile_spec(), self.Mul(), test_data
         )
 
-    @parameterized.expand(test_data_sute)
+    # Numerical issues on FVP, MLETORCH-521
+    # test_data_sute[0] works on U85
+    @parameterized.expand(test_data_sute[1:])
+    @conftest.expectedFailureOnFVP
     def test_mul_u85_BI(
         self,
         test_name: str,

backends/arm/test/runner_utils.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 import numpy as np
 import torch
 
-from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled
+from executorch.backends.arm.test.conftest import is_option_enabled
 
 from torch.export import ExportedProgram
 from torch.fx.node import Node
@@ -251,7 +251,7 @@ def run_corstone(
         cmd_line += f" -i {input_path}"
 
     ethos_u_extra_args = ""
-    if is_option_enabled(arm_test_options.fast_fvp):
+    if is_option_enabled("fast_fvp"):
         ethos_u_extra_args = ethos_u_extra_args + "--fast"
 
     command_args = {

backends/cadence/aot/ops_registrations.py

Lines changed: 31 additions & 2 deletions

@@ -146,7 +146,10 @@
     "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
-
+lib.define(
+    "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.ymal        #
@@ -192,6 +195,10 @@
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected.per_tensor")
+def quantized_fully_connected_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
     out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
     weight_size = list(weight.size())
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
-    return src.new_empty(out_size, dtype=torch.uint8)
+    return src.new_empty(out_size, dtype=src.dtype)
 
 
 @register_fake("cadence::convolution")
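The new .per_tensor fake kernel and the dtype fix share one contract: the output shape is [leading_dims, out_dim] and the output dtype now follows src rather than being hardcoded to torch.uint8. A small self-contained check of that shape/dtype arithmetic (example tensors are illustrative):

import torch

src = torch.zeros(4, 16, dtype=torch.int8)      # [leading_dims, in_dim]
weight = torch.zeros(32, 16, dtype=torch.int8)  # [out_dim, in_dim]

out_size = list(src.size())
out_size[-1] = weight.size(0)                   # out_dim replaces in_dim
out = src.new_empty(out_size, dtype=src.dtype)  # dtype tracks src

assert out.shape == (4, 32)
assert out.dtype == torch.int8                  # was torch.uint8 before the fix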

backends/cadence/aot/replace_ops.py

Lines changed: 6 additions & 6 deletions

@@ -9,6 +9,8 @@
 # 3. functions that replace an ATen op with another semantically equivalent ATen op.
 # 4. functions that concretize optional args.
 
+# pyre-unsafe
+
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -1698,12 +1700,6 @@ def call_operator(self, op, args, kwargs, meta):
         if leading_dims != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        # If the op is quantized::linear, but per-channel quantized, bail.
-        if op == exir_ops.edge.cadence.quantized_linear.default:
-            weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1]
-            if weight.shape != [1]:
-                return super().call_operator(op, args, kwargs, meta)
-
         # Replace the linear with fully connected op
         return super().call_operator(
             self.linear_to_fc_op[op],
@@ -1893,6 +1889,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
         ),
+        exir_ops.edge.cadence.quantized_fully_connected: (
+            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
+            [4, 5, 6],
+        ),
         exir_ops.edge.cadence.quantized_layer_norm: (
             exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
             [1, 2],
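The new dictionary entry registers quantized_fully_connected with ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass. Counting through the schema defined in ops_registrations.py (src=0, weight=1, bias=2, src_zero_point=3), indices [4, 5, 6] name weight_zero_point, out_multiplier, and out_shift: exactly the three Tensor arguments that become int in the .per_tensor overload. A hedged sketch of the substitution such a pass performs (illustrative helper, not the actual pass code):

import torch


def scalarize(args: tuple, indices: list[int]) -> tuple:
    # Replace single-element tensors (e.g. produced by aten.full) at the
    # listed argument positions with their Python scalar value.
    new_args = list(args)
    for i in indices:
        arg = new_args[i]
        if isinstance(arg, torch.Tensor) and arg.numel() == 1:
            new_args[i] = arg.item()
    return tuple(new_args)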
