2323from kernel_tuner .backends .opencl import OpenCLFunctions
2424from kernel_tuner .backends .hip import HipFunctions
2525from kernel_tuner .observers .nvml import NVMLObserver
26- from kernel_tuner .observers .observer import ContinuousObserver , OutputObserver
26+ from kernel_tuner .observers .observer import ContinuousObserver , OutputObserver , PrologueObserver
2727
2828try :
2929 import torch
@@ -314,11 +314,13 @@ def __init__(
314314 )
315315 else :
316316 raise ValueError ("Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet" )
317+ self .dev = dev
317318
318319 # look for NVMLObserver in observers, if present, enable special tunable parameters through nvml
319320 self .use_nvml = False
320321 self .continuous_observers = []
321322 self .output_observers = []
323+ self .prologue_observers = []
322324 if observers :
323325 for obs in observers :
324326 if isinstance (obs , NVMLObserver ):
@@ -328,49 +330,61 @@ def __init__(
328330 self .continuous_observers .append (obs .continuous_observer )
329331 if isinstance (obs , OutputObserver ):
330332 self .output_observers .append (obs )
333+ if isinstance (obs , PrologueObserver ):
334+ self .prologue_observers .append (obs )
331335
336+ # Take list of observers from self.dev because Backends tend to add their own observer
337+ self .benchmark_observers = [
338+ obs for obs in self .dev .observers if not isinstance (obs , (ContinuousObserver , PrologueObserver ))
339+ ]
332340
333341 self .iterations = iterations
334342
335343 self .lang = lang
336- self .dev = dev
337344 self .units = dev .units
338345 self .name = dev .name
339346 self .max_threads = dev .max_threads
340347 if not quiet :
341348 print ("Using: " + self .dev .name )
342349
350+ def benchmark_prologue (self , func , gpu_args , threads , grid , result ):
351+ """Benchmark prologue one kernel execution per PrologueObserver"""
352+
353+ for obs in self .prologue_observers :
354+ self .dev .synchronize ()
355+ obs .before_start ()
356+ self .dev .run_kernel (func , gpu_args , threads , grid )
357+ self .dev .synchronize ()
358+ obs .after_finish ()
359+ result .update (obs .get_results ())
360+
343361 def benchmark_default (self , func , gpu_args , threads , grid , result ):
344- """Benchmark one kernel execution at a time."""
345- observers = [
346- obs for obs in self .dev .observers if not isinstance (obs , ContinuousObserver )
347- ]
362+ """Benchmark one kernel execution for 'iterations' at a time"""
348363
349364 self .dev .synchronize ()
350365 for _ in range (self .iterations ):
351- for obs in observers :
366+ for obs in self . benchmark_observers :
352367 obs .before_start ()
353368 self .dev .synchronize ()
354369 self .dev .start_event ()
355370 self .dev .run_kernel (func , gpu_args , threads , grid )
356371 self .dev .stop_event ()
357- for obs in observers :
372+ for obs in self . benchmark_observers :
358373 obs .after_start ()
359374 while not self .dev .kernel_finished ():
360- for obs in observers :
375+ for obs in self . benchmark_observers :
361376 obs .during ()
362377 time .sleep (1e-6 ) # one microsecond
363378 self .dev .synchronize ()
364- for obs in observers :
379+ for obs in self . benchmark_observers :
365380 obs .after_finish ()
366381
367- for obs in observers :
382+ for obs in self . benchmark_observers :
368383 result .update (obs .get_results ())
369384
370385 def benchmark_continuous (self , func , gpu_args , threads , grid , result , duration ):
371386 """Benchmark continuously for at least 'duration' seconds"""
372387 iterations = int (np .ceil (duration / (result ["time" ] / 1000 )))
373- # print(f"{iterations=} {(result['time']/1000)=}")
374388 self .dev .synchronize ()
375389 for obs in self .continuous_observers :
376390 obs .before_start ()
@@ -420,9 +434,8 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett
420434
421435 result = {}
422436 try :
423- self .benchmark_default (
424- func , gpu_args , instance .threads , instance .grid , result
425- )
437+ self .benchmark_prologue (func , gpu_args , instance .threads , instance .grid , result )
438+ self .benchmark_default (func , gpu_args , instance .threads , instance .grid , result )
426439
427440 if self .continuous_observers :
428441 duration = 1
0 commit comments