@@ -269,6 +269,7 @@ def benchmarks(self) -> list[Benchmark]:
269269 )
270270 )
271271
272+ # Add RecordAndReplay benchmarks
272273 record_and_replay_params = product ([0 , 1 ], [0 , 1 ])
273274 for emulate , instantiate in record_and_replay_params :
274275
@@ -315,6 +316,39 @@ def createRrBench(variant_name: str, **kwargs):
315316 ),
316317 ]
317318
319+ # Add TorchMultiQueue benchmarks
320+ for runtime in filter (lambda x : x != RUNTIMES .UR , RUNTIMES ):
321+
322+ def createTorchMultiQueueBench (variant_name : str , ** kwargs ):
323+ return TorchMultiQueue (
324+ self ,
325+ runtime ,
326+ variant_name ,
327+ PROFILERS .TIMER ,
328+ ** kwargs ,
329+ )
330+
331+ benches += [
332+ createTorchMultiQueueBench (
333+ "large" ,
334+ workgroupCount = 4096 ,
335+ workgroupSize = 512 ,
336+ kernelsPerQueue = 20 ,
337+ ),
338+ createTorchMultiQueueBench (
339+ "medium" ,
340+ workgroupCount = 512 ,
341+ workgroupSize = 256 ,
342+ kernelsPerQueue = 10 ,
343+ ),
344+ createTorchMultiQueueBench (
345+ "small" ,
346+ workgroupCount = 256 ,
347+ workgroupSize = 124 ,
348+ kernelsPerQueue = 4 ,
349+ ),
350+ ]
351+
318352 # Add UR-specific benchmarks
319353 benches += [
320354 # TODO: multithread_benchmark_ur fails with segfault
@@ -770,6 +804,48 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
770804 return [f"--{ k } ={ v } " for k , v in self ._rr_params .items ()]
771805
772806
807+ class TorchMultiQueue (ComputeBenchmark ):
808+ def __init__ (
809+ self , suite , runtime : RUNTIMES , variant_name : str , profiler_type , ** kwargs
810+ ):
811+ self ._variant_name = variant_name
812+ self ._smq_params = kwargs
813+ self ._iterations_regular = 1000
814+ self ._iterations_trace = 10
815+ super ().__init__ (
816+ suite ,
817+ f"torch_benchmark_{ runtime .value } " ,
818+ "KernelSubmitMultiQueue" ,
819+ runtime ,
820+ profiler_type ,
821+ )
822+
823+ def name (self ):
824+ ret = []
825+ for k , v in self ._smq_params .items ():
826+ ret .append (f"{ k } { v } " )
827+ ret .sort ()
828+ return self ._bench_name + " " + ", " .join (ret )
829+
830+ def display_name (self ) -> str :
831+ return f"{ self .explicit_group ()} { self ._runtime .value } "
832+
833+ def explicit_group (self ):
834+ return f"{ self ._test } { self ._variant_name } "
835+
836+ def get_tags (self ):
837+ return ["pytorch" , runtime_to_tag_name (self ._runtime )]
838+
839+ def _supported_runtimes (self ) -> list [RUNTIMES ]:
840+ return super ()._supported_runtimes () + [RUNTIMES .SYCL_PREVIEW ]
841+
842+ def _bin_args (self , run_trace : TracingType = TracingType .NONE ) -> list [str ]:
843+ iters = self ._get_iters (run_trace )
844+ return [f"--iterations={ iters } " ] + [
845+ f"--{ k } ={ v } " for k , v in self ._smq_params .items ()
846+ ]
847+
848+
773849class QueueInOrderMemcpy (ComputeBenchmark ):
774850 def __init__ (self , bench , isCopyOnly , source , destination , size , profiler_type ):
775851 self ._is_copy_only = isCopyOnly
0 commit comments