1111from .utils import (
1212 create_test_comparator ,
1313 infinicore_tensor_from_torch ,
14- profile_operation ,
1514)
15+ from .benchmark import BenchmarkUtils
1616
1717
1818@dataclass
@@ -21,8 +21,10 @@ class TestResult:
2121
2222 success : bool
2323 return_code : int # 0: success, -1: failure, -2: skipped, -3: partial
24- torch_time : float = 0.0
25- infini_time : float = 0.0
24+ torch_host_time : float = 0.0
25+ torch_device_time : float = 0.0
26+ infini_host_time : float = 0.0
27+ infini_device_time : float = 0.0
2628 error_message : str = ""
2729 test_case : Any = None
2830 device : Any = None
@@ -202,8 +204,10 @@ def __init__(self, test_cases, test_config):
202204 ) # Track passed tests (both operators implemented and passed)
203205 # Add benchmark timing statistics
204206 self .benchmark_times = {
205- "torch_total" : 0.0 ,
206- "infinicore_total" : 0.0 ,
207+ "torch_host_total" : 0.0 ,
208+ "torch_device_total" : 0.0 ,
209+ "infinicore_host_total" : 0.0 ,
210+ "infinicore_device_total" : 0.0 ,
207211 "per_test_case" : {}, # Store timing per test case
208212 }
209213 # Store test results
@@ -329,8 +333,10 @@ def print_summary(self):
329333
330334 # Print benchmark summary if benchmarking was enabled
331335 if self .config .bench and (
332- self .benchmark_times ["torch_total" ] > 0
333- or self .benchmark_times ["infinicore_total" ] > 0
336+ self .benchmark_times ["torch_host_total" ] > 0
337+ or self .benchmark_times ["torch_device_total" ] > 0
338+ or self .benchmark_times ["infinicore_host_total" ] > 0
339+ or self .benchmark_times ["infinicore_device_total" ] > 0
334340 ):
335341 self ._print_benchmark_summary ()
336342
@@ -342,19 +348,30 @@ def _print_benchmark_summary(self):
342348 print (f"{ '-' * 60 } " )
343349 print ("BENCHMARK SUMMARY" )
344350
345- torch_total = self .benchmark_times ["torch_total" ]
346- infinicore_total = self .benchmark_times ["infinicore_total" ]
351+ torch_host_total = self .benchmark_times ["torch_host_total" ]
352+ torch_device_total = self .benchmark_times ["torch_device_total" ]
353+ infinicore_host_total = self .benchmark_times ["infinicore_host_total" ]
354+ infinicore_device_total = self .benchmark_times ["infinicore_device_total" ]
355+
356+ if torch_host_total > 0 :
357+ print (f"PyTorch Host Total Time: { torch_host_total * 1000 :.3f} ms" )
358+ if torch_device_total > 0 :
359+ print (f"PyTorch Device Total Time: { torch_device_total * 1000 :.3f} ms" )
360+ if infinicore_host_total > 0 :
361+ print (f"InfiniCore Host Total Time: { infinicore_host_total * 1000 :.3f} ms" )
362+ if infinicore_device_total > 0 :
363+ print (
364+ f"InfiniCore Device Total Time: { infinicore_device_total * 1000 :.3f} ms"
365+ )
347366
348- if torch_total > 0 :
349- print ( f"PyTorch Total Time: { torch_total * 1000 :.3f } ms" )
350- if infinicore_total > 0 :
351- print (f"InfiniCore Total Time : { infinicore_total * 1000 :.3f } ms " )
367+ # Calculate speedups
368+ if torch_host_total > 0 and infinicore_host_total > 0 :
369+ host_speedup = torch_host_total / infinicore_host_total
370+ print (f"Host Speedup (PyTorch/InfiniCore) : { host_speedup :.2f } x " )
352371
353- if torch_total > 0 and infinicore_total > 0 :
354- speedup = (
355- torch_total / infinicore_total if infinicore_total > 0 else float ("inf" )
356- )
357- print (f"Speedup (PyTorch/InfiniCore): { speedup :.2f} x" )
372+ if torch_device_total > 0 and infinicore_device_total > 0 :
373+ device_speedup = torch_device_total / infinicore_device_total
374+ print (f"Device Speedup (PyTorch/InfiniCore): { device_speedup :.2f} x" )
358375
359376 def get_test_results (self ):
360377 """Get all test results"""
@@ -593,20 +610,27 @@ def run_test(self, device, test_case, config):
593610 test_result .return_code = - 3 # Partial
594611 # Run benchmarking for partial tests if enabled
595612 if config .bench :
596- torch_time , infini_time = self ._run_benchmarking (
597- config ,
598- device_str ,
599- torch_implemented ,
600- infini_implemented ,
601- inputs ,
602- kwargs ,
603- infini_inputs ,
604- infini_kwargs ,
605- test_case .output_count ,
606- comparison_target ,
613+ torch_host , torch_device , infini_host , infini_device = (
614+ BenchmarkUtils .run_benchmarking (
615+ config ,
616+ device_str ,
617+ torch_implemented ,
618+ infini_implemented ,
619+ self .torch_operator ,
620+ self .infinicore_operator ,
621+ inputs ,
622+ kwargs ,
623+ infini_inputs ,
624+ infini_kwargs ,
625+ test_case .output_count ,
626+ comparison_target ,
627+ bench_mode = config .bench ,
628+ )
607629 )
608- test_result .torch_time = torch_time
609- test_result .infini_time = infini_time
630+ test_result .torch_host_time = torch_host
631+ test_result .torch_device_time = torch_device
632+ test_result .infini_host_time = infini_host
633+ test_result .infini_device_time = infini_device
610634 return test_result
611635 # ==========================================================================
612636 # MULTIPLE OUTPUTS COMPARISON LOGIC
@@ -716,109 +740,43 @@ def run_test(self, device, test_case, config):
716740 # UNIFIED BENCHMARKING LOGIC
717741 # ==========================================================================
718742 if config .bench :
719- torch_time , infini_time = self ._run_benchmarking (
720- config ,
721- device_str ,
722- True ,
723- True ,
724- inputs ,
725- kwargs ,
726- infini_inputs ,
727- infini_kwargs ,
728- test_case .output_count ,
729- comparison_target ,
743+ torch_host , torch_device , infini_host , infini_device = (
744+ BenchmarkUtils .run_benchmarking (
745+ config ,
746+ device_str ,
747+ True ,
748+ True ,
749+ self .torch_operator ,
750+ self .infinicore_operator ,
751+ inputs ,
752+ kwargs ,
753+ infini_inputs ,
754+ infini_kwargs ,
755+ test_case .output_count ,
756+ comparison_target ,
757+ bench_mode = config .bench ,
758+ )
730759 )
731- test_result .torch_time = torch_time
732- test_result .infini_time = infini_time
760+ test_result .torch_host_time = torch_host
761+ test_result .torch_device_time = torch_device
762+ test_result .infini_host_time = infini_host
763+ test_result .infini_device_time = infini_device
764+
765+ # Store timing information in the test runner
766+ if hasattr (config , "_test_runner" ) and config ._test_runner :
767+ # Accumulate total times
768+ config ._test_runner .benchmark_times ["torch_host_total" ] += torch_host
769+ config ._test_runner .benchmark_times [
770+ "torch_device_total"
771+ ] += torch_device
772+ config ._test_runner .benchmark_times [
773+ "infinicore_host_total"
774+ ] += infini_host
775+ config ._test_runner .benchmark_times [
776+ "infinicore_device_total"
777+ ] += infini_device
733778
734779 # Test passed successfully
735780 test_result .success = True
736781 test_result .return_code = 0
737782 return test_result
738-
739- def _run_benchmarking (
740- self ,
741- config ,
742- device_str ,
743- torch_implemented ,
744- infini_implemented ,
745- inputs ,
746- kwargs ,
747- infini_inputs ,
748- infini_kwargs ,
749- output_count ,
750- comparison_target ,
751- ):
752- """
753- Unified benchmarking logic with timing accumulation
754-
755- Returns:
756- tuple: (torch_time, infini_time) timing results
757- """
758- # Initialize timing variables
759- torch_time = 0.0
760- infini_time = 0.0
761-
762- if torch_implemented :
763- if output_count > 1 :
764- # For multiple outputs, just call the operator
765- def torch_op ():
766- return self .torch_operator (* inputs , ** kwargs )
767-
768- else :
769- if comparison_target is None :
770- # Out-of-place benchmarking
771- def torch_op ():
772- return self .torch_operator (* inputs , ** kwargs )
773-
774- else :
775- # In-place benchmarking
776- def torch_op ():
777- self .torch_operator (* inputs , ** kwargs )
778- return (
779- kwargs .get ("out" )
780- if "out" in kwargs
781- else inputs [comparison_target ]
782- )
783-
784- torch_time = profile_operation (
785- "PyTorch " ,
786- torch_op ,
787- device_str ,
788- config .num_prerun ,
789- config .num_iterations ,
790- total = True ,
791- )
792-
793- if infini_implemented :
794- if comparison_target is None :
795- # Out-of-place benchmarking
796- def infini_op ():
797- return self .infinicore_operator (* infini_inputs , ** infini_kwargs )
798-
799- else :
800- # In-place benchmarking
801- def infini_op ():
802- self .infinicore_operator (* infini_inputs , ** infini_kwargs )
803- return (
804- infini_kwargs .get ("out" )
805- if "out" in infini_kwargs
806- else infini_inputs [comparison_target ]
807- )
808-
809- infini_time = profile_operation (
810- "InfiniCore" ,
811- infini_op ,
812- device_str ,
813- config .num_prerun ,
814- config .num_iterations ,
815- total = True ,
816- )
817-
818- # Store timing information in the test runner
819- if hasattr (config , "_test_runner" ) and config ._test_runner :
820- # Accumulate total times
821- config ._test_runner .benchmark_times ["torch_total" ] += torch_time
822- config ._test_runner .benchmark_times ["infinicore_total" ] += infini_time
823-
824- return torch_time , infini_time
0 commit comments