Skip to content

Commit 23610fc

Browse files
authored
[test] add unified jit importer test config (#4083)
* make test config more clear
1 parent daad812 commit 23610fc

File tree

5 files changed

+11
-120
lines changed

5 files changed

+11
-120
lines changed

projects/pt1/e2e_testing/main.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@
1919
# Available test configs.
2020
from torch_mlir_e2e_test.configs import (
2121
LazyTensorCoreTestConfig,
22-
LinalgOnTensorsBackendTestConfig,
23-
StablehloBackendTestConfig,
2422
NativeTorchTestConfig,
2523
OnnxBackendTestConfig,
2624
TorchScriptTestConfig,
27-
TosaBackendTestConfig,
2825
TorchDynamoTestConfig,
26+
JITImporterTestConfig,
2927
FxImporterTestConfig,
3028
)
3129

@@ -152,15 +150,15 @@ def main():
152150

153151
# Find the selected config.
154152
if args.config == "linalg":
155-
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
153+
config = JITImporterTestConfig(RefBackendLinalgOnTensorsBackend())
156154
xfail_set = LINALG_XFAIL_SET
157155
crashing_set = LINALG_CRASHING_SET
158156
elif args.config == "stablehlo":
159-
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
157+
config = JITImporterTestConfig(LinalgOnTensorsStablehloBackend(), "stablehlo")
160158
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
161159
crashing_set = STABLEHLO_CRASHING_SET
162160
elif args.config == "tosa":
163-
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
161+
config = JITImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
164162
xfail_set = all_test_unique_names - TOSA_PASS_SET
165163
crashing_set = TOSA_CRASHING_SET
166164
elif args.config == "native_torch":

projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
# Also available under a BSD-style license. See LICENSE.
55

66
from .lazy_tensor_core import LazyTensorCoreTestConfig
7-
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
87
from .native_torch import NativeTorchTestConfig
98
from .onnx_backend import OnnxBackendTestConfig
109
from .torchscript import TorchScriptTestConfig
11-
from .stablehlo_backend import StablehloBackendTestConfig
12-
from .tosa_backend import TosaBackendTestConfig
1310
from .torchdynamo import TorchDynamoTestConfig
11+
from .jit_importer_backend import JITImporterTestConfig
1412
from .fx_importer_backend import FxImporterTestConfig

projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py renamed to projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,27 @@
88
import torch
99
from torch_mlir import torchscript
1010

11-
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
1211
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
1312
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
13+
1414
from .utils import (
1515
recursively_convert_to_numpy,
1616
recursively_convert_from_numpy,
1717
)
1818

1919

20-
class StablehloBackendTestConfig(TestConfig):
21-
"""Base class for TestConfig's that are implemented with StableHLO.
22-
23-
This class handles all the common lowering that torch-mlir does before
24-
reaching the StableHLO abstraction level.
25-
"""
20+
class JITImporterTestConfig(TestConfig):
21+
"""TestConfig that runs the torch.nn.Module with JIT Importer"""
2622

27-
def __init__(self, backend: StablehloBackend):
23+
def __init__(self, backend, output_type="linalg-on-tensors"):
2824
super().__init__()
2925
self.backend = backend
26+
self.output_type = output_type
3027

3128
def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any:
3229
example_args = convert_annotations_to_placeholders(program.forward)
3330
module = torchscript.compile(
34-
program, example_args, output_type="stablehlo", verbose=verbose
31+
program, example_args, output_type=self.output_type, verbose=verbose
3532
)
3633

3734
return self.backend.compile(module)

projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)