11import torch
22import infinicore
33import traceback
4-
4+ from dataclasses import dataclass
55from abc import ABC , abstractmethod
6- from typing import List , Dict , Any , Optional
6+ from typing import List , Dict , Any , Optional , Tuple
77
88from .datatypes import to_torch_dtype , to_infinicore_dtype
99from .devices import InfiniDeviceNames , torch_device_map
1515)
1616
1717
18+ @dataclass
19+ class TestResult :
20+ """Test result data structure"""
21+ success : bool
22+ return_code : int # 0: success, -1: failure, -2: skipped, -3: partial
23+ torch_time : float = 0.0
24+ infini_time : float = 0.0
25+ error_message : str = ""
26+ test_case : Any = None
27+ device : Any = None
28+
29+
1830class TestCase :
1931 """Test case with all configuration included"""
2032
@@ -23,11 +35,11 @@ def __init__(
2335 inputs ,
2436 kwargs = None ,
2537 output_spec = None ,
38+ output_specs = None ,
2639 comparison_target = None ,
2740 description = "" ,
2841 tolerance = None ,
2942 output_count = 1 ,
30- output_specs = None ,
3143 ):
3244 """
3345 Initialize a test case with complete configuration
@@ -248,6 +260,8 @@ def __init__(self, test_cases, test_config):
248260 "infinicore_total" : 0.0 ,
249261 "per_test_case" : {}, # Store timing per test case
250262 }
263+ # Store test results
264+ self .test_results = []
251265
252266 def run_tests (self , devices , test_func , test_type = "Test" ):
253267 """
@@ -270,33 +284,47 @@ def run_tests(self, devices, test_func, test_type="Test"):
270284 try :
271285 print (f"{ test_case } " )
272286
273- # Execute test and get result status
274- success , status = test_func (device , test_case , self .config )
287+ # Execute test and get TestResult object
288+ test_result = test_func (device , test_case , self .config )
289+ self .test_results .append (test_result )
275290
276- # Handle different test statuses
277- if status == "passed" :
291+ # Handle different test statuses based on return_code
292+ if test_result . return_code == 0 : # Success
278293 self .passed_tests .append (
279294 f"{ test_case } - { InfiniDeviceNames [device ]} "
280295 )
281296 print (f"\033 [92m✓\033 [0m Passed" )
282- elif status == "skipped" :
283- # Test was skipped due to both operators not being implemented
297+ elif test_result .return_code == - 1 :
298+ fail_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - Test terminated in verbose mode."
299+ self .failed_tests .append (fail_msg )
300+ elif test_result .return_code == - 2 : # Skipped
284301 skip_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - Both operators not implemented"
285302 self .skipped_tests .append (skip_msg )
286- elif status == "partial" :
287- # Test was partially executed (one operator not implemented)
303+ print ( f" \033 [93m⚠ \033 [0m Both operators not implemented - test skipped" )
304+ elif test_result . return_code == - 3 : # Partial
288305 partial_msg = f"{ test_case } - { InfiniDeviceNames [device ]} - One operator not implemented"
289306 self .partial_tests .append (partial_msg )
307+ print (f"\033 [93m⚠\033 [0m One operator not implemented - running single operator without comparison" )
290308
291- # Failed tests are handled in the exception handler below
309+ if self .config .verbose and test_result .return_code != 0 :
310+ return False
292311
293312 except Exception as e :
294313 error_msg = (
295314 f"{ test_case } - { InfiniDeviceNames [device ]} - Error: { e } "
296315 )
297316 print (f"\033 [91m✗\033 [0m { error_msg } " )
298317 self .failed_tests .append (error_msg )
299-
318+
319+ # Create a failed TestResult
320+ failed_result = TestResult (
321+ success = False ,
322+ return_code = - 1 ,
323+ error_message = str (e ),
324+ test_case = test_case ,
325+ device = device
326+ )
327+ self .test_results .append (failed_result )
300328 # In verbose mode, print full traceback and stop execution
301329 if self .config .verbose :
302330 traceback .print_exc ()
@@ -305,8 +333,7 @@ def run_tests(self, devices, test_func, test_type="Test"):
305333 if self .config .debug :
306334 raise
307335
308- # Return True if no tests failed (skipped/partial tests don't count as failures)
309- return len (self .failed_tests ) == 0
336+ return len (self .failed_tests ) == 0 and len (self .skipped_tests ) == 0 and len (self .partial_tests ) == 0
310337
311338 def print_summary (self ):
312339 """
@@ -377,6 +404,10 @@ def _print_benchmark_summary(self):
377404 )
378405 print (f"Speedup (PyTorch/InfiniCore): { speedup :.2f} x" )
379406
407+ def get_test_results (self ):
408+ """Get all test results"""
409+ return self .test_results
410+
380411
381412class BaseOperatorTest (ABC ):
382413 """Base operator test"""
@@ -480,11 +511,17 @@ def run_test(self, device, test_case, config):
480511 config: Test configuration
481512
482513 Returns:
483- tuple: (success, status) where:
484- success: bool indicating if test passed
485- status: str describing test status ("passed", "skipped", "partial")
514+ TestResult: Test result object containing status and timing information
486515 """
487516 device_str = torch_device_map [device ]
517+
518+ # Initialize test result
519+ test_result = TestResult (
520+ success = False ,
521+ return_code = - 1 , # Default to failure
522+ test_case = test_case ,
523+ device = device
524+ )
488525
489526 # Prepare inputs and kwargs with actual tensors
490527 inputs , kwargs = self .prepare_inputs_and_kwargs (test_case , device )
@@ -559,7 +596,10 @@ def run_test(self, device, test_case, config):
559596 except NotImplementedError :
560597 if config .verbose :
561598 traceback .print_exc ()
562- return False # Stop test execution immediately
599+ # Return test result immediately in verbose mode
600+ test_result .return_code = - 1
601+ test_result .error_message = "torch_operator not implemented"
602+ return test_result
563603 torch_implemented = False
564604 torch_result = None
565605
@@ -570,26 +610,24 @@ def run_test(self, device, test_case, config):
570610 except NotImplementedError :
571611 if config .verbose :
572612 traceback .print_exc ()
573- return False # Stop test execution immediately
613+ # Return test result immediately in verbose mode
614+ test_result .return_code = - 1
615+ test_result .error_message = "infinicore_operator not implemented"
616+ return test_result
574617 infini_implemented = False
575618 infini_result = None
576619
577620 # Skip if neither operator is implemented
578621 if not torch_implemented and not infini_implemented :
579- print ( f" \033 [93m⚠ \033 [0m Both operators not implemented - test skipped" )
580- return False , "skipped"
622+ test_result . return_code = - 2 # Skipped
623+ return test_result
581624
582625 # Single operator execution without comparison
583626 if not torch_implemented or not infini_implemented :
584- missing_op = (
585- "torch_operator" if not torch_implemented else "infinicore_operator"
586- )
587- print (
588- f"\033 [93m⚠\033 [0m { missing_op } not implemented - running single operator without comparison"
589- )
590-
627+ test_result .return_code = - 3 # Partial
628+ # Run benchmarking for partial tests if enabled
591629 if config .bench :
592- self ._run_benchmarking (
630+ torch_time , infini_time = self ._run_benchmarking (
593631 config ,
594632 device_str ,
595633 torch_implemented ,
@@ -601,8 +639,9 @@ def run_test(self, device, test_case, config):
601639 test_case .output_count ,
602640 comparison_target ,
603641 )
604- return False , "partial"
605-
642+ test_result .torch_time = torch_time
643+ test_result .infini_time = infini_time
644+ return test_result
606645 # ==========================================================================
607646 # MULTIPLE OUTPUTS COMPARISON LOGIC
608647 # ==========================================================================
@@ -711,7 +750,7 @@ def run_test(self, device, test_case, config):
711750 # UNIFIED BENCHMARKING LOGIC
712751 # ==========================================================================
713752 if config .bench :
714- self ._run_benchmarking (
753+ torch_time , infini_time = self ._run_benchmarking (
715754 config ,
716755 device_str ,
717756 True ,
@@ -723,9 +762,13 @@ def run_test(self, device, test_case, config):
723762 test_case .output_count ,
724763 comparison_target ,
725764 )
765+ test_result .torch_time = torch_time
766+ test_result .infini_time = infini_time
726767
727768 # Test passed successfully
728- return True , "passed"
769+ test_result .success = True
770+ test_result .return_code = 0
771+ return test_result
729772
730773 def _run_benchmarking (
731774 self ,
@@ -742,8 +785,10 @@ def _run_benchmarking(
742785 ):
743786 """
744787 Unified benchmarking logic with timing accumulation
745- """
746788
789+ Returns:
790+ tuple: (torch_time, infini_time) timing results
791+ """
747792 # Initialize timing variables
748793 torch_time = 0.0
749794 infini_time = 0.0
@@ -809,3 +854,5 @@ def infini_op():
809854 # Accumulate total times
810855 config ._test_runner .benchmark_times ["torch_total" ] += torch_time
811856 config ._test_runner .benchmark_times ["infinicore_total" ] += infini_time
857+
858+ return torch_time , infini_time
0 commit comments