Skip to content

Commit ec00ecf

Browse files
committed
Update
[ghstack-poisoned]
2 parents 386e994 + a931676 commit ec00ecf

21 files changed

+903
-87
lines changed

backends/apple/coreml/test/tester.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
git # Copyright (c) Meta Platforms, Inc. and affiliates.
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type
7+
from typing import Any, List, Optional, Tuple
88

99
import executorch
1010
import executorch.backends.test.harness.stages as BaseStages
@@ -59,4 +59,3 @@ def __init__(
5959
example_inputs=example_inputs,
6060
dynamic_shapes=dynamic_shapes,
6161
)
62-

backends/test/compliance_suite/__init__.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import logging
111
import os
212
import unittest
313

414
from enum import Enum
515
from typing import Any, Callable, Tuple
616

7-
import logging
817
import torch
918
from executorch.backends.test.harness import Tester
1019

@@ -15,20 +24,23 @@
1524
# Read enabled backends from the environment variable. Enable all if
1625
# not specified (signalled by None).
1726
def get_enabled_backends():
18-
et_test_backends = os.environ.get("ET_TEST_BACKENDS")
27+
et_test_backends = os.environ.get("ET_TEST_ENABLED_BACKENDS")
1928
if et_test_backends is not None:
2029
return et_test_backends.split(",")
2130
else:
2231
return None
2332

33+
2434
_ENABLED_BACKENDS = get_enabled_backends()
2535

36+
2637
def is_backend_enabled(backend):
2738
if _ENABLED_BACKENDS is None:
2839
return True
2940
else:
3041
return backend in _ENABLED_BACKENDS
3142

43+
3244
ALL_TEST_FLOWS = []
3345

3446
if is_backend_enabled("xnnpack"):
@@ -58,49 +70,71 @@ def is_backend_enabled(backend):
5870
torch.float64,
5971
]
6072

73+
FLOAT_DTYPES = [
74+
torch.float16,
75+
torch.float32,
76+
torch.float64,
77+
]
78+
79+
6180
class TestType(Enum):
6281
STANDARD = 1
6382
DTYPE = 2
6483

84+
6585
def dtype_test(func):
66-
setattr(func, "test_type", TestType.DTYPE)
86+
func.test_type = TestType.DTYPE
6787
return func
6888

89+
6990
def operator_test(cls):
7091
_create_tests(cls)
7192
return cls
7293

94+
7395
def _create_tests(cls):
7496
for key in dir(cls):
7597
if key.startswith("test_"):
7698
_expand_test(cls, key)
77-
99+
100+
78101
def _expand_test(cls, test_name: str):
79102
test_func = getattr(cls, test_name)
80-
for (flow_name, tester_factory) in ALL_TEST_FLOWS:
81-
_create_test_for_backend(cls, test_func, flow_name, tester_factory)
103+
for flow_name, tester_factory in ALL_TEST_FLOWS:
104+
_create_test_for_backend(cls, test_func, flow_name, tester_factory)
82105
delattr(cls, test_name)
83106

107+
108+
def _make_wrapped_test(test_func, tester_factory):
109+
def wrapped_test(self):
110+
test_func(self, tester_factory)
111+
112+
return tester_factory
113+
114+
115+
def _make_wrapped_dtype_test(test_func, dtype, tester_factory):
116+
def wrapped_test(self):
117+
test_func(self, dtype, tester_factory)
118+
119+
return wrapped_test
120+
121+
84122
def _create_test_for_backend(
85123
cls,
86124
test_func: Callable,
87125
flow_name: str,
88-
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester]
126+
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester],
89127
):
90128
test_type = getattr(test_func, "test_type", TestType.STANDARD)
91129

92130
if test_type == TestType.STANDARD:
93-
def wrapped_test(self):
94-
test_func(self, tester_factory)
95-
131+
wrapped_test = _make_wrapped_test(test_func, tester_factory)
96132
test_name = f"{test_func.__name__}_{flow_name}"
97133
setattr(cls, test_name, wrapped_test)
98134
elif test_type == TestType.DTYPE:
99135
for dtype in DTYPES:
100-
def wrapped_test(self):
101-
test_func(self, dtype, tester_factory)
102-
103-
dtype_name = str(dtype)[6:] # strip "torch."
136+
wrapped_test = _make_wrapped_dtype_test(test_func, dtype, tester_factory)
137+
dtype_name = str(dtype)[6:] # strip "torch."
104138
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
105139
setattr(cls, test_name, wrapped_test)
106140
else:
@@ -126,10 +160,4 @@ def _test_op(self, model, inputs, tester_factory):
126160

127161
# Only run the runtime test if the op was delegated.
128162
if is_delegated:
129-
(
130-
tester
131-
.to_executorch()
132-
.serialize()
133-
.run_method_and_compare_outputs()
134-
)
135-
163+
(tester.to_executorch().serialize().run_method_and_compare_outputs())
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
28

3-
# pyre-strict
49

510
from typing import Callable
611

@@ -12,10 +17,12 @@
1217
OperatorTest,
1318
)
1419

20+
1521
class Model(torch.nn.Module):
1622
def forward(self, x, y):
1723
return x + y
1824

25+
1926
class ModelAlpha(torch.nn.Module):
2027
def __init__(self, alpha):
2128
super().__init__()
@@ -24,6 +31,7 @@ def __init__(self, alpha):
2431
def forward(self, x, y):
2532
return torch.add(x, y, alpha=self.alpha)
2633

34+
2735
@operator_test
2836
class Add(OperatorTest):
2937
@dtype_test
@@ -34,41 +42,45 @@ def test_add_dtype(self, dtype, tester_factory: Callable) -> None:
3442
(torch.rand(2, 10) * 100).to(dtype),
3543
(torch.rand(2, 10) * 100).to(dtype),
3644
),
37-
tester_factory)
38-
45+
tester_factory,
46+
)
47+
3948
def test_add_f32_bcast_first(self, tester_factory: Callable) -> None:
4049
self._test_op(
41-
Model(),
50+
Model(),
4251
(
4352
torch.randn(5),
4453
torch.randn(1, 5, 1, 5),
4554
),
46-
tester_factory)
47-
55+
tester_factory,
56+
)
57+
4858
def test_add_f32_bcast_second(self, tester_factory: Callable) -> None:
4959
self._test_op(
50-
Model(),
60+
Model(),
5161
(
5262
torch.randn(4, 4, 2, 7),
5363
torch.randn(2, 7),
5464
),
55-
tester_factory)
65+
tester_factory,
66+
)
5667

5768
def test_add_f32_bcast_unary(self, tester_factory: Callable) -> None:
5869
self._test_op(
59-
Model(),
70+
Model(),
6071
(
6172
torch.randn(5),
6273
torch.randn(1, 1, 5),
6374
),
64-
tester_factory)
65-
75+
tester_factory,
76+
)
77+
6678
def test_add_f32_alpha(self, tester_factory: Callable) -> None:
6779
self._test_op(
68-
ModelAlpha(alpha=2),
80+
ModelAlpha(alpha=2),
6981
(
7082
torch.randn(1, 25),
7183
torch.randn(1, 25),
7284
),
73-
tester_factory)
74-
85+
tester_factory,
86+
)
Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
28

3-
# pyre-strict
49

510
from typing import Callable, Optional
611

@@ -12,10 +17,12 @@
1217
OperatorTest,
1318
)
1419

20+
1521
class Model(torch.nn.Module):
1622
def forward(self, x, y):
1723
return x / y
1824

25+
1926
class ModelWithRounding(torch.nn.Module):
2027
def __init__(self, rounding_mode: Optional[str]):
2128
super().__init__()
@@ -24,6 +31,7 @@ def __init__(self, rounding_mode: Optional[str]):
2431
def forward(self, x, y):
2532
return torch.div(x, y, rounding_mode=self.rounding_mode)
2633

34+
2735
@operator_test
2836
class Divide(OperatorTest):
2937
@dtype_test
@@ -32,51 +40,64 @@ def test_divide_dtype(self, dtype, tester_factory: Callable) -> None:
3240
Model(),
3341
(
3442
(torch.rand(2, 10) * 100).to(dtype),
35-
(torch.rand(2, 10) * 100 + 0.1).to(dtype), # Adding 0.1 to avoid division by zero
43+
(torch.rand(2, 10) * 100 + 0.1).to(
44+
dtype
45+
), # Adding 0.1 to avoid division by zero
3646
),
37-
tester_factory)
38-
47+
tester_factory,
48+
)
49+
3950
def test_divide_f32_bcast_first(self, tester_factory: Callable) -> None:
4051
self._test_op(
41-
Model(),
52+
Model(),
4253
(
4354
torch.randn(5),
44-
torch.randn(1, 5, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero
55+
torch.randn(1, 5, 1, 5).abs()
56+
+ 0.1, # Using abs and adding 0.1 to avoid division by zero
4557
),
46-
tester_factory)
47-
58+
tester_factory,
59+
)
60+
4861
def test_divide_f32_bcast_second(self, tester_factory: Callable) -> None:
4962
self._test_op(
50-
Model(),
63+
Model(),
5164
(
5265
torch.randn(4, 4, 2, 7),
53-
torch.randn(2, 7).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero
66+
torch.randn(2, 7).abs()
67+
+ 0.1, # Using abs and adding 0.1 to avoid division by zero
5468
),
55-
tester_factory)
69+
tester_factory,
70+
)
5671

5772
def test_divide_f32_bcast_unary(self, tester_factory: Callable) -> None:
5873
self._test_op(
59-
Model(),
74+
Model(),
6075
(
6176
torch.randn(5),
62-
torch.randn(1, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero
77+
torch.randn(1, 1, 5).abs()
78+
+ 0.1, # Using abs and adding 0.1 to avoid division by zero
6379
),
64-
tester_factory)
65-
80+
tester_factory,
81+
)
82+
6683
def test_divide_f32_trunc(self, tester_factory: Callable) -> None:
6784
self._test_op(
68-
ModelWithRounding(rounding_mode="trunc"),
85+
ModelWithRounding(rounding_mode="trunc"),
6986
(
7087
torch.randn(3, 4) * 10,
71-
torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero
88+
torch.randn(3, 4).abs()
89+
+ 0.1, # Using abs and adding 0.1 to avoid division by zero
7290
),
73-
tester_factory)
74-
91+
tester_factory,
92+
)
93+
7594
def test_divide_f32_floor(self, tester_factory: Callable) -> None:
7695
self._test_op(
77-
ModelWithRounding(rounding_mode="floor"),
96+
ModelWithRounding(rounding_mode="floor"),
7897
(
7998
torch.randn(3, 4) * 10,
80-
torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero
99+
torch.randn(3, 4).abs()
100+
+ 0.1, # Using abs and adding 0.1 to avoid division by zero
81101
),
82-
tester_factory)
102+
tester_factory,
103+
)

0 commit comments

Comments
 (0)