Skip to content

Commit 32bd2f8

Browse files
committed
Issue/787 - class/method rename && move structs code to results file
1 parent 515b92f commit 32bd2f8

File tree

13 files changed

+87
-92
lines changed

13 files changed

+87
-92
lines changed

test/infinicore/framework/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99
)
1010
from .datatypes import to_torch_dtype, to_infinicore_dtype
1111
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
12+
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
1213
from .runner import GenericTestRunner
1314
from .tensor import TensorSpec, TensorInitializer
14-
from .structs import TestTiming, OperatorResult, CaseResult
15-
from .summary import TestSummary
16-
from .driver import TestDriver
15+
from .executor import TestExecutor
1716
from .utils.compare_utils import (
1817
compare_results,
1918
create_test_comparator,
@@ -44,7 +43,7 @@
4443
"TensorSpec",
4544
"TestCase",
4645
"TestConfig",
47-
"TestDriver",
46+
"TestExecutor",
4847
"TestSummary",
4948
"TestRunner",
5049
"TestTiming",

test/infinicore/framework/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import traceback
99
from abc import ABC, abstractmethod
1010

11-
from .structs import CaseResult
11+
from .results import CaseResult
1212
from .datatypes import to_torch_dtype, to_infinicore_dtype
1313
from .devices import InfiniDeviceNames, torch_device_map
1414
from .tensor import TensorSpec, TensorInitializer
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import importlib.util
33
from io import StringIO
44
from contextlib import contextmanager
5-
from .structs import OperatorResult
6-
from .summary import TestSummary
5+
from .results import OperatorResult, TestSummary
76

87

98
@contextmanager
@@ -18,8 +17,8 @@ def capture_output():
1817
sys.stdout, sys.stderr = old_out, old_err
1918

2019

21-
class TestDriver:
22-
def drive(self, file_path) -> OperatorResult:
20+
class TestExecutor:
21+
def execute(self, file_path) -> OperatorResult:
2322
result = OperatorResult(name=file_path.stem)
2423

2524
try:
@@ -53,7 +52,6 @@ def drive(self, file_path) -> OperatorResult:
5352

5453
test_summary = TestSummary()
5554
test_summary.process_operator_result(result, test_results)
56-
# test_summary._extract_timing(result, test_results)
5755

5856
except Exception as e:
5957
result.success = False
Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,71 @@
11
from typing import List, Dict, Any
2-
from dataclasses import is_dataclass
2+
from dataclasses import dataclass, is_dataclass, field
33
from .devices import InfiniDeviceEnum
4-
from .base import TensorSpec
4+
from .tensor import TensorSpec
55
from .utils.json_utils import save_json_report
66

7+
@dataclass
8+
class CaseResult:
9+
"""Test case result data structure"""
10+
11+
success: bool
12+
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
13+
torch_host_time: float = 0.0
14+
torch_device_time: float = 0.0
15+
infini_host_time: float = 0.0
16+
infini_device_time: float = 0.0
17+
error_message: str = ""
18+
test_case: Any = None
19+
device: Any = None
20+
21+
22+
@dataclass
23+
class TestTiming:
24+
"""Stores performance timing metrics."""
25+
26+
torch_host: float = 0.0
27+
torch_device: float = 0.0
28+
infini_host: float = 0.0
29+
infini_device: float = 0.0
30+
# Added field to support the logic in your print_summary
31+
operators_tested: int = 0
32+
33+
34+
@dataclass
35+
class OperatorResult:
36+
"""Stores the execution results of a single operator."""
37+
38+
name: str
39+
success: bool = False
40+
return_code: int = -1
41+
error_message: str = ""
42+
stdout: str = ""
43+
stderr: str = ""
44+
timing: TestTiming = field(default_factory=TestTiming)
45+
46+
@property
47+
def status_icon(self):
48+
if self.return_code == 0:
49+
return "✅"
50+
if self.return_code == -2:
51+
return "⏭️"
52+
if self.return_code == -3:
53+
return "⚠️"
54+
return "❌"
55+
56+
@property
57+
def status_text(self):
58+
if self.return_code == 0:
59+
return "PASSED"
60+
if self.return_code == -2:
61+
return "SKIPPED"
62+
if self.return_code == -3:
63+
return "PARTIAL"
64+
return "FAILED"
765

866
class TestSummary:
967
"""
10-
Test summary:
68+
Test Summary class:
1169
1. Aggregates results (Timing & Status calculation).
1270
2. Handles Console Output (Live & Summary).
1371
3. Handles File Reporting (Data Preparation).

test/infinicore/framework/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
import re
99
from . import TestConfig, TestRunner, get_args, get_test_devices
10-
from .summary import TestSummary
10+
from .results import TestSummary
1111

1212

1313
class GenericTestRunner:

test/infinicore/framework/structs.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

test/infinicore/framework/utils/json_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def save_json_report(save_path, total_results):
1919
print(f"💾 Saving to: {final_path}")
2020

2121
# Helper for JSON stringify to avoid repetition
22-
def _j(obj):
22+
def _to_json(obj):
2323
return json.dumps(obj, ensure_ascii=False)
2424

2525
try:
@@ -58,16 +58,16 @@ def _j(obj):
5858
)
5959

6060
if c_key in ["kwargs", "inputs"]:
61-
_write_smart_field(
61+
_write_field(
6262
f, c_key, c_val, I16, I20, close_comma=c_comma
6363
)
6464
else:
65-
f.write(f'{I16}"{c_key}": {_j(c_val)}{c_comma}\n')
65+
f.write(f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n')
6666

6767
# Handle trailing comparison/tolerance fields uniformly
6868
if "comparison_target" in case_item:
69-
cmp = _j(case_item.get("comparison_target"))
70-
tol = _j(case_item.get("tolerance"))
69+
cmp = _to_json(case_item.get("comparison_target"))
70+
tol = _to_json(case_item.get("tolerance"))
7171
f.write(
7272
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
7373
)
@@ -77,7 +77,7 @@ def _j(obj):
7777
f.write(f"{I8}]{comma}\n")
7878
else:
7979
# Standard top-level fields
80-
f.write(f"{I8}{_j(key)}: {_j(val)}{comma}\n")
80+
f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
8181

8282
close_entry = "}," if i < len(total_results) - 1 else "}"
8383
f.write(f"{I4}{close_entry}\n")
@@ -90,9 +90,9 @@ def _j(obj):
9090
print(f" ❌ Save failed: {e}")
9191

9292

93-
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
93+
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
9494
"""
95-
Internal Helper: Write a JSON field with smart wrapping.
95+
Internal Helper: Write a JSON field with wrapping.
9696
"""
9797
# 1. Try Compact Mode
9898
compact_json = json.dumps(value, ensure_ascii=False)

test/infinicore/ops/adaptive_max_pool2d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import infinicore
88
from framework import (
99
BaseOperatorTest,
10+
CaseResult,
1011
TensorSpec,
1112
TestCase,
1213
GenericTestRunner,

test/infinicore/ops/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
88
from framework.tensor import TensorInitializer
9-
from framework.utils import (
9+
from framework.utils.tensor_utils import (
1010
convert_infinicore_to_torch,
1111
infinicore_tensor_from_torch,
1212
to_torch_dtype,

test/infinicore/ops/random_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ def run_test(self, device, test_case, config):
222222

223223
# Re-run operations with the same logits to get results for comparison
224224
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
225-
from framework.base import CaseResult
226-
from framework.utils import (
225+
from framework.results import CaseResult
226+
from framework.utils.tensor_utils import (
227227
convert_infinicore_to_torch,
228228
infinicore_tensor_from_torch,
229229
)

0 commit comments

Comments
 (0)