Skip to content

Commit 12cde8e

Browse files
authored
Merge pull request #788 from InfiniTensor/issue/787
issue/787 - Split run ops test logic and fix kwargs name in report
2 parents 62fe699 + 7aece93 commit 12cde8e

File tree

19 files changed

+1068
-1218
lines changed

19 files changed

+1068
-1218
lines changed

test/infinicore/framework/__init__.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .base import TestConfig, TestRunner, BaseOperatorTest
2-
from .test_case import TestCase, TestResult
2+
from .entities import TestCase
33
from .benchmark import BenchmarkUtils, BenchmarkResult
44
from .config import (
55
add_common_test_args,
@@ -9,35 +9,44 @@
99
)
1010
from .datatypes import to_torch_dtype, to_infinicore_dtype
1111
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
12+
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
1213
from .runner import GenericTestRunner
1314
from .tensor import TensorSpec, TensorInitializer
14-
from .utils import (
15+
from .executor import TestExecutor
16+
from .utils.compare_utils import (
1517
compare_results,
1618
create_test_comparator,
1719
debug,
1820
get_tolerance,
21+
)
22+
from .utils.json_utils import save_json_report
23+
from .utils.tensor_utils import (
1924
infinicore_tensor_from_torch,
20-
rearrange_tensor,
2125
convert_infinicore_to_torch,
26+
rearrange_tensor,
2227
is_broadcast,
28+
is_integer_dtype,
2329
is_complex_dtype,
2430
is_floating_dtype,
25-
is_integer_dtype,
2631
)
2732

33+
2834
__all__ = [
2935
# Core types and classes
3036
"BaseOperatorTest",
37+
"CaseResult",
3138
"GenericTestRunner",
3239
"InfiniDeviceEnum",
3340
"InfiniDeviceNames",
41+
"OperatorResult",
3442
"TensorInitializer",
3543
"TensorSpec",
3644
"TestCase",
3745
"TestConfig",
38-
"TestResult",
46+
"TestExecutor",
47+
"TestSummary",
3948
"TestRunner",
40-
"TestReporter",
49+
"TestTiming",
4150
# Core functions
4251
"add_common_test_args",
4352
"compare_results",
@@ -50,6 +59,8 @@
5059
"get_tolerance",
5160
"infinicore_tensor_from_torch",
5261
"rearrange_tensor",
62+
# Json utilites
63+
"save_json_report",
5364
# Utility functions
5465
"to_infinicore_dtype",
5566
"to_torch_dtype",

test/infinicore/framework/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
import traceback
99
from abc import ABC, abstractmethod
1010

11-
from .test_case import TestCase, TestResult
11+
from .results import CaseResult
1212
from .datatypes import to_torch_dtype, to_infinicore_dtype
1313
from .devices import InfiniDeviceNames, torch_device_map
1414
from .tensor import TensorSpec, TensorInitializer
15-
from .utils import (
15+
from .utils.tensor_utils import (
1616
clone_torch_tensor,
17-
create_test_comparator,
1817
infinicore_tensor_from_torch,
1918
)
19+
from .utils.compare_utils import create_test_comparator
2020
from .benchmark import BenchmarkUtils
2121

2222

@@ -84,7 +84,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
8484
try:
8585
print(f"{test_case}")
8686

87-
# Execute test and get TestResult object
87+
# Execute test and get CaseResult object
8888
test_result = test_func(device, test_case, self.config)
8989
self.test_results.append(test_result)
9090

@@ -118,8 +118,8 @@ def run_tests(self, devices, test_func, test_type="Test"):
118118
print(f"\033[91m✗\033[0m {error_msg}")
119119
self.failed_tests.append(error_msg)
120120

121-
# Create a failed TestResult
122-
failed_result = TestResult(
121+
# Create a failed CaseResult
122+
failed_result = CaseResult(
123123
success=False,
124124
return_code=-1,
125125
error_message=str(e),
@@ -400,12 +400,12 @@ def run_test(self, device, test_case, config):
400400
config: Test configuration
401401
402402
Returns:
403-
TestResult: Test result object containing status and timing information
403+
CaseResult: Test case result object containing status and timing information
404404
"""
405405
device_str = torch_device_map[device]
406406

407-
# Initialize test result
408-
test_result = TestResult(
407+
# Initialize test case result
408+
test_result = CaseResult(
409409
success=False,
410410
return_code=-1, # Default to failure
411411
test_case=test_case,

test/infinicore/framework/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import torch
77
import infinicore
8-
from .utils import synchronize_device
8+
from .utils.tensor_utils import synchronize_device
99

1010

1111
class BenchmarkUtils:
Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,6 @@
77
from .tensor import TensorSpec
88

99

10-
@dataclass
11-
class TestResult:
12-
"""Test result data structure"""
13-
14-
success: bool
15-
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
16-
torch_host_time: float = 0.0
17-
torch_device_time: float = 0.0
18-
infini_host_time: float = 0.0
19-
infini_device_time: float = 0.0
20-
error_message: str = ""
21-
test_case: Any = None
22-
device: Any = None
23-
24-
2510
class TestCase:
2611
"""Test case with all configuration included"""
2712

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import sys
2+
import importlib.util
3+
from io import StringIO
4+
from contextlib import contextmanager
5+
from .results import OperatorResult, TestSummary
6+
7+
8+
@contextmanager
9+
def capture_output():
10+
"""Context manager: captures stdout and stderr."""
11+
new_out, new_err = StringIO(), StringIO()
12+
old_out, old_err = sys.stdout, sys.stderr
13+
try:
14+
sys.stdout, sys.stderr = new_out, new_err
15+
yield new_out, new_err
16+
finally:
17+
sys.stdout, sys.stderr = old_out, old_err
18+
19+
20+
class TestExecutor:
21+
def execute(self, file_path) -> OperatorResult:
22+
result = OperatorResult(name=file_path.stem)
23+
24+
try:
25+
# 1. Dynamically import the module
26+
module = self._import_module(file_path)
27+
28+
# 2. Look for TestRunner
29+
if not hasattr(module, "GenericTestRunner"):
30+
raise ImportError("No GenericTestRunner found in module")
31+
32+
# 3. Look for TestClass (subclass of BaseOperatorTest)
33+
test_class = self._find_test_class(module)
34+
if not test_class:
35+
raise ImportError("No BaseOperatorTest subclass found")
36+
37+
test_instance = test_class()
38+
runner_class = module.GenericTestRunner
39+
runner = runner_class(test_instance.__class__)
40+
41+
# 4. Execute and capture output
42+
with capture_output() as (out, err):
43+
success, internal_runner = runner.run()
44+
45+
# 5. Populate results
46+
result.success = success
47+
result.stdout = out.getvalue()
48+
result.stderr = err.getvalue()
49+
50+
# Extract detailed results from internal_runner
51+
test_results = internal_runner.get_test_results() if internal_runner else []
52+
53+
test_summary = TestSummary()
54+
test_summary.process_operator_result(result, test_results)
55+
56+
except Exception as e:
57+
result.success = False
58+
result.error_message = str(e)
59+
result.stderr += f"\nExecutor Error: {str(e)}"
60+
result.return_code = -1
61+
62+
return result
63+
64+
def _import_module(self, path):
65+
module_name = f"op_test_{path.stem}"
66+
spec = importlib.util.spec_from_file_location(module_name, path)
67+
if not spec or not spec.loader:
68+
raise ImportError(f"Could not load spec from {path}")
69+
module = importlib.util.module_from_spec(spec)
70+
sys.modules[module_name] = module
71+
spec.loader.exec_module(module)
72+
return module
73+
74+
def _find_test_class(self, module):
75+
for attr_name in dir(module):
76+
attr = getattr(module, attr_name)
77+
if isinstance(attr, type) and hasattr(attr, "__bases__"):
78+
# Simple check for base class name
79+
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
80+
return attr
81+
return None

0 commit comments

Comments
 (0)