Skip to content

Commit 6662abe

Browse files
committed
issue/598 - optimize run.py performance
1 parent 30d6a72 commit 6662abe

File tree

4 files changed

+300
-175
lines changed

4 files changed

+300
-175
lines changed

test/infinicore/framework/base.py

Lines changed: 82 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
22
import infinicore
33
import traceback
4-
4+
from dataclasses import dataclass
55
from abc import ABC, abstractmethod
6-
from typing import List, Dict, Any, Optional
6+
from typing import List, Dict, Any, Optional, Tuple
77

88
from .datatypes import to_torch_dtype, to_infinicore_dtype
99
from .devices import InfiniDeviceNames, torch_device_map
@@ -15,6 +15,18 @@
1515
)
1616

1717

18+
@dataclass
19+
class TestResult:
20+
"""Test result data structure"""
21+
success: bool
22+
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
23+
torch_time: float = 0.0
24+
infini_time: float = 0.0
25+
error_message: str = ""
26+
test_case: Any = None
27+
device: Any = None
28+
29+
1830
class TestCase:
1931
"""Test case with all configuration included"""
2032

@@ -23,11 +35,11 @@ def __init__(
2335
inputs,
2436
kwargs=None,
2537
output_spec=None,
38+
output_specs=None,
2639
comparison_target=None,
2740
description="",
2841
tolerance=None,
2942
output_count=1,
30-
output_specs=None,
3143
):
3244
"""
3345
Initialize a test case with complete configuration
@@ -248,6 +260,8 @@ def __init__(self, test_cases, test_config):
248260
"infinicore_total": 0.0,
249261
"per_test_case": {}, # Store timing per test case
250262
}
263+
# Store test results
264+
self.test_results = []
251265

252266
def run_tests(self, devices, test_func, test_type="Test"):
253267
"""
@@ -270,33 +284,47 @@ def run_tests(self, devices, test_func, test_type="Test"):
270284
try:
271285
print(f"{test_case}")
272286

273-
# Execute test and get result status
274-
success, status = test_func(device, test_case, self.config)
287+
# Execute test and get TestResult object
288+
test_result = test_func(device, test_case, self.config)
289+
self.test_results.append(test_result)
275290

276-
# Handle different test statuses
277-
if status == "passed":
291+
# Handle different test statuses based on return_code
292+
if test_result.return_code == 0: # Success
278293
self.passed_tests.append(
279294
f"{test_case} - {InfiniDeviceNames[device]}"
280295
)
281296
print(f"\033[92m✓\033[0m Passed")
282-
elif status == "skipped":
283-
# Test was skipped due to both operators not being implemented
297+
elif test_result.return_code == -1:
298+
fail_msg = f"{test_case} - {InfiniDeviceNames[device]} - Test terminated in verbose mode."
299+
self.failed_tests.append(fail_msg)
300+
elif test_result.return_code == -2: # Skipped
284301
skip_msg = f"{test_case} - {InfiniDeviceNames[device]} - Both operators not implemented"
285302
self.skipped_tests.append(skip_msg)
286-
elif status == "partial":
287-
# Test was partially executed (one operator not implemented)
303+
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
304+
elif test_result.return_code == -3: # Partial
288305
partial_msg = f"{test_case} - {InfiniDeviceNames[device]} - One operator not implemented"
289306
self.partial_tests.append(partial_msg)
307+
print(f"\033[93m⚠\033[0m One operator not implemented - running single operator without comparison")
290308

291-
# Failed tests are handled in the exception handler below
309+
if self.config.verbose and test_result.return_code != 0:
310+
return False
292311

293312
except Exception as e:
294313
error_msg = (
295314
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
296315
)
297316
print(f"\033[91m✗\033[0m {error_msg}")
298317
self.failed_tests.append(error_msg)
299-
318+
319+
# Create a failed TestResult
320+
failed_result = TestResult(
321+
success=False,
322+
return_code=-1,
323+
error_message=str(e),
324+
test_case=test_case,
325+
device=device
326+
)
327+
self.test_results.append(failed_result)
300328
# In verbose mode, print full traceback and stop execution
301329
if self.config.verbose:
302330
traceback.print_exc()
@@ -305,8 +333,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
305333
if self.config.debug:
306334
raise
307335

308-
# Return True if no tests failed (skipped/partial tests don't count as failures)
309-
return len(self.failed_tests) == 0
336+
return len(self.failed_tests) == 0 and len(self.skipped_tests) == 0 and len(self.partial_tests) == 0
310337

311338
def print_summary(self):
312339
"""
@@ -377,6 +404,10 @@ def _print_benchmark_summary(self):
377404
)
378405
print(f"Speedup (PyTorch/InfiniCore): {speedup:.2f}x")
379406

407+
def get_test_results(self):
408+
"""Get all test results"""
409+
return self.test_results
410+
380411

381412
class BaseOperatorTest(ABC):
382413
"""Base operator test"""
@@ -480,11 +511,17 @@ def run_test(self, device, test_case, config):
480511
config: Test configuration
481512
482513
Returns:
483-
tuple: (success, status) where:
484-
success: bool indicating if test passed
485-
status: str describing test status ("passed", "skipped", "partial")
514+
TestResult: Test result object containing status and timing information
486515
"""
487516
device_str = torch_device_map[device]
517+
518+
# Initialize test result
519+
test_result = TestResult(
520+
success=False,
521+
return_code=-1, # Default to failure
522+
test_case=test_case,
523+
device=device
524+
)
488525

489526
# Prepare inputs and kwargs with actual tensors
490527
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
@@ -559,7 +596,10 @@ def run_test(self, device, test_case, config):
559596
except NotImplementedError:
560597
if config.verbose:
561598
traceback.print_exc()
562-
return False # Stop test execution immediately
599+
# Return test result immediately in verbose mode
600+
test_result.return_code = -1
601+
test_result.error_message = "torch_operator not implemented"
602+
return test_result
563603
torch_implemented = False
564604
torch_result = None
565605

@@ -570,26 +610,24 @@ def run_test(self, device, test_case, config):
570610
except NotImplementedError:
571611
if config.verbose:
572612
traceback.print_exc()
573-
return False # Stop test execution immediately
613+
# Return test result immediately in verbose mode
614+
test_result.return_code = -1
615+
test_result.error_message = "infinicore_operator not implemented"
616+
return test_result
574617
infini_implemented = False
575618
infini_result = None
576619

577620
# Skip if neither operator is implemented
578621
if not torch_implemented and not infini_implemented:
579-
print(f"\033[93m⚠\033[0m Both operators not implemented - test skipped")
580-
return False, "skipped"
622+
test_result.return_code = -2 # Skipped
623+
return test_result
581624

582625
# Single operator execution without comparison
583626
if not torch_implemented or not infini_implemented:
584-
missing_op = (
585-
"torch_operator" if not torch_implemented else "infinicore_operator"
586-
)
587-
print(
588-
f"\033[93m⚠\033[0m {missing_op} not implemented - running single operator without comparison"
589-
)
590-
627+
test_result.return_code = -3 # Partial
628+
# Run benchmarking for partial tests if enabled
591629
if config.bench:
592-
self._run_benchmarking(
630+
torch_time, infini_time = self._run_benchmarking(
593631
config,
594632
device_str,
595633
torch_implemented,
@@ -601,8 +639,9 @@ def run_test(self, device, test_case, config):
601639
test_case.output_count,
602640
comparison_target,
603641
)
604-
return False, "partial"
605-
642+
test_result.torch_time = torch_time
643+
test_result.infini_time = infini_time
644+
return test_result
606645
# ==========================================================================
607646
# MULTIPLE OUTPUTS COMPARISON LOGIC
608647
# ==========================================================================
@@ -711,7 +750,7 @@ def run_test(self, device, test_case, config):
711750
# UNIFIED BENCHMARKING LOGIC
712751
# ==========================================================================
713752
if config.bench:
714-
self._run_benchmarking(
753+
torch_time, infini_time = self._run_benchmarking(
715754
config,
716755
device_str,
717756
True,
@@ -723,9 +762,13 @@ def run_test(self, device, test_case, config):
723762
test_case.output_count,
724763
comparison_target,
725764
)
765+
test_result.torch_time = torch_time
766+
test_result.infini_time = infini_time
726767

727768
# Test passed successfully
728-
return True, "passed"
769+
test_result.success = True
770+
test_result.return_code = 0
771+
return test_result
729772

730773
def _run_benchmarking(
731774
self,
@@ -742,8 +785,10 @@ def _run_benchmarking(
742785
):
743786
"""
744787
Unified benchmarking logic with timing accumulation
745-
"""
746788
789+
Returns:
790+
tuple: (torch_time, infini_time) timing results
791+
"""
747792
# Initialize timing variables
748793
torch_time = 0.0
749794
infini_time = 0.0
@@ -809,3 +854,5 @@ def infini_op():
809854
# Accumulate total times
810855
config._test_runner.benchmark_times["torch_total"] += torch_time
811856
config._test_runner.benchmark_times["infinicore_total"] += infini_time
857+
858+
return torch_time, infini_time

test/infinicore/framework/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ def get_args():
100100

101101
# Device options using shared hardware info
102102
hardware_group = get_hardware_args_group(parser)
103+
args, unknown = parser.parse_known_args()
103104

104-
return parser.parse_args()
105+
return args
105106

106107

107108
def get_test_devices(args):

test/infinicore/framework/runner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def run(self):
2121
"""Execute the complete test suite
2222
2323
Returns:
24-
bool: True if all tests passed or were skipped/partial, False if any tests failed
24+
tuple: (success, test_runner) where:
25+
success: bool indicating if all tests passed or were skipped/partial
26+
test_runner: TestRunner instance with test results
2527
"""
2628
config = TestConfig(
2729
debug=self.args.debug,
@@ -51,7 +53,7 @@ def run(self):
5153
# Both conditions must be True for overall success
5254
# - has_no_failures: no test failures during execution
5355
# - summary_passed: summary confirms no failures
54-
return has_no_failures and summary_passed
56+
return (has_no_failures and summary_passed), runner
5557

5658
def run_and_exit(self):
5759
"""Run tests and exit with appropriate status code
@@ -60,5 +62,5 @@ def run_and_exit(self):
6062
0: All tests passed or were skipped/partial (no failures)
6163
1: One or more tests failed
6264
"""
63-
success = self.run()
65+
success, runner = self.run()
6466
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)