Skip to content

Commit 1839f4b

Browse files
authored
Refactor suite namespace (#103)
1 parent 9c97184 commit 1839f4b

File tree

12 files changed

+80
-41
lines changed

12 files changed

+80
-41
lines changed

BackendBench/scripts/generate_operator_coverage_csv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
extract_aten_ops,
1818
extract_operator_name,
1919
)
20-
from BackendBench.opinfo_suite import OpInfoTestSuite
21-
from BackendBench.torchbench_suite import TorchBenchTestSuite
20+
from BackendBench.suite import OpInfoTestSuite, TorchBenchTestSuite
2221

2322

2423
def get_torchbench_ops():

BackendBench/scripts/get_tests_stat.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
import statistics
1212

1313
import torch
14-
from BackendBench.facto_suite import FactoTestSuite
15-
from BackendBench.opinfo_suite import OpInfoTestSuite
16-
from BackendBench.torchbench_suite import TorchBenchTestSuite
14+
from BackendBench.suite import OpInfoTestSuite, TorchBenchTestSuite, FactoTestSuite
1715
from BackendBench.scripts.pytorch_operators import extract_operator_name
1816

1917

BackendBench/scripts/main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
import click
1616
import torch
1717

18-
from BackendBench.facto_suite import FactoTestSuite
1918
from BackendBench.llm_client import ClaudeKernelGenerator, LLMKernelGenerator
20-
from BackendBench.opinfo_suite import OpInfoTestSuite
21-
from BackendBench.suite import SmokeTestSuite
22-
from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL, TorchBenchTestSuite
19+
from BackendBench.suite import (
20+
SmokeTestSuite,
21+
OpInfoTestSuite,
22+
DEFAULT_HUGGINGFACE_URL,
23+
TorchBenchTestSuite,
24+
FactoTestSuite,
25+
)
2326

2427
logger = logging.getLogger(__name__)
2528

BackendBench/suite/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
BackendBench suites submodule.
9+
10+
This module provides various test suite implementations for benchmarking
11+
PyTorch operations across different backends. Each test suite defines a
12+
collection of tests to evaluate the correctness and/or performacne of
13+
backend implementations by comparing them against PyTorch operations.
14+
"""
15+
16+
from .base import Test, OpTest, TestSuite
17+
from .facto import FactoTestSuite
18+
from .opinfo import OpInfoTestSuite
19+
from .smoke import SmokeTestSuite, randn
20+
from .torchbench import TorchBenchOpTest, TorchBenchTestSuite, DEFAULT_HUGGINGFACE_URL
21+
22+
__all__ = [
23+
"Test",
24+
"OpTest",
25+
"TestSuite",
26+
"FactoTestSuite",
27+
"OpInfoTestSuite",
28+
"SmokeTestSuite",
29+
"randn",
30+
"TorchBenchOpTest",
31+
"TorchBenchTestSuite",
32+
"DEFAULT_HUGGINGFACE_URL",
33+
]

BackendBench/suite.py renamed to BackendBench/suite/base.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
8-
9-
from BackendBench.opregistry import get_operator
10-
11-
12-
def randn(*args, **kwargs):
13-
return lambda: torch.randn(*args, **kwargs)
14-
157

168
class Test:
179
def __init__(self, *args, **kwargs):
@@ -42,19 +34,3 @@ def __init__(self, name, optests):
4234
def __iter__(self):
4335
for optest in self.optests:
4436
yield optest
45-
46-
47-
SmokeTestSuite = TestSuite(
48-
"smoke",
49-
[
50-
OpTest(
51-
get_operator(torch.ops.aten.relu.default),
52-
[
53-
Test(randn(2, device="cpu")),
54-
],
55-
[
56-
Test(randn(2**28, device="cpu")),
57-
],
58-
)
59-
],
60-
)

BackendBench/facto_suite.py renamed to BackendBench/suite/facto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
SpecDictDB = None
2121

2222

23-
from .eval import allclose
24-
from .opregistry import get_operator
25-
from .suite import OpTest, TestSuite
23+
from BackendBench.eval import allclose
24+
from BackendBench.opregistry import get_operator
25+
from .base import OpTest, TestSuite
2626

2727
logger = logging.getLogger(__name__)
2828

BackendBench/opinfo_suite.py renamed to BackendBench/suite/opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from torch.testing._internal.common_methods_invocations import op_db
1111
from torch.utils._python_dispatch import TorchDispatchMode
1212

13-
from .eval import allclose
14-
from .suite import OpTest, TestSuite
13+
from BackendBench.eval import allclose
14+
from .base import OpTest, TestSuite
1515

1616
logger = logging.getLogger(__name__)
1717

BackendBench/suite/smoke.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
9+
from BackendBench.opregistry import get_operator
10+
from .base import Test, OpTest, TestSuite
11+
12+
13+
def randn(*args, **kwargs):
14+
return lambda: torch.randn(*args, **kwargs)
15+
16+
17+
SmokeTestSuite = TestSuite(
18+
"smoke",
19+
[
20+
OpTest(
21+
get_operator(torch.ops.aten.relu.default),
22+
[
23+
Test(randn(2, device="cpu")),
24+
],
25+
[
26+
Test(randn(2**28, device="cpu")),
27+
],
28+
)
29+
],
30+
)
File renamed without changes.

test/test_adverse_cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import pytest
8-
from BackendBench.torchbench_suite import TorchBenchOpTest
8+
from BackendBench.suite import TorchBenchOpTest
99
import BackendBench.multiprocessing_eval as multiprocessing_eval
1010
import BackendBench.backends as backends
1111
import torch

0 commit comments

Comments
 (0)