
import torch.cuda

-from utils import set_seed
+from utils import set_seed, clear_l2_cache
+
try:
    from task import TestSpec
except ImportError:
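Note: clear_l2_cache comes from utils and is not shown in this diff. A typical implementation (an assumption here, not the repository's actual helper) simply overwrites a buffer larger than the GPU's L2 cache so that each timed run starts from a cold cache:

import torch

def clear_l2_cache():
    # Hypothetical sketch: writing a large scratch buffer evicts resident L2 lines.
    scratch = torch.empty(256 * 1024 * 1024 // 4, dtype=torch.float32, device='cuda')
    scratch.zero_()
    torch.cuda.synchronize()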
@@ -24,16 +25,16 @@ class PopcornOutput:
    def __init__(self, fd: int):
        self.file = os.fdopen(fd, 'w')
        os.set_inheritable(fd, False)
-
+
    def __enter__(self):
        return self
-
+
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()
-
+
    def print(self, *args, **kwargs):
        print(*args, **kwargs, file=self.file, flush=True)
-
+
    def log(self, key, value):
        self.print(f"{key}: {value}")

@@ -52,7 +53,7 @@ def _combine(a: int, b: int) -> int:
    # so we need to make sure they don't provide any useful info for the full seed.
    # This Cantor construction ensures that if the secret seed is a large number,
    # then so is the overall seed.
-    return int(a + (a + b) * (a + b + 1) // 2)
+    return int(a + (a + b) * (a + b + 1) // 2)


def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
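As a quick check of the Cantor-style pairing above (numbers chosen purely for illustration): _combine(3, 5) = 3 + (3 + 5) * (3 + 5 + 1) // 2 = 3 + 36 = 39, and a large secret component keeps the combined seed large because the paired term grows quadratically:

# Illustration of the pairing formula only, not part of the diff.
assert 3 + (3 + 5) * (3 + 5 + 1) // 2 == 39
assert 7 + (7 + 10**9) * (7 + 10**9 + 1) // 2 > 10**17  # large secret seed => large combined seed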
@@ -114,7 +115,7 @@ def calculate_stats(durations: list[int]):
    worst = max(durations)

    avg = total / runs
-    variance = sum(map(lambda x: (x - avg)**2, durations))
+    variance = sum(map(lambda x: (x - avg) ** 2, durations))
    std = math.sqrt(variance / (runs - 1))
    err = std / math.sqrt(runs)

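For reference, the surrounding function computes the sample standard deviation with Bessel's correction (runs - 1) and the standard error of the mean (std / sqrt(runs)), which the benchmark loop later compares against 0.1% of the mean. A standalone sketch of the same arithmetic, for illustration only:

import math

def stats_sketch(durations):
    # Mirrors the calculation above.
    runs = len(durations)
    avg = sum(durations) / runs
    variance = sum((x - avg) ** 2 for x in durations)
    std = math.sqrt(variance / (runs - 1))  # sample standard deviation
    err = std / math.sqrt(runs)             # standard error of the mean
    return avg, std, err

# e.g. stats_sketch([100, 102, 98, 101]) gives avg = 100.25 and err = std / 2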
@@ -219,6 +220,7 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
    # otherwise, we repeat until we either measure at least 10 full seconds,
    # or the relative error of the mean is below 1%.

+    bm_start_time = time.perf_counter_ns()
    for i in range(max_repeats):
        if recheck:
            # ensure we use a different seed for every benchmark
@@ -228,28 +230,39 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
        data = generate_input(**test.args)
        check_copy = _clone_data(data)
        torch.cuda.synchronize()
-        start = time.perf_counter_ns()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        clear_l2_cache()
+
+        start_event.record()
        output = custom_kernel(data)
+        end_event.record()
        torch.cuda.synchronize()
-        end = time.perf_counter_ns()
+        duration = start_event.elapsed_time(end_event) * 1e6  # Convert ms to ns

        if recheck:
            good, message = check_implementation(check_copy, output)
            if not good:
                return message

        del output
-        durations.append(end - start)
+        durations.append(duration)

        if i > 1:
+            total_bm_duration = time.perf_counter_ns() - bm_start_time
            stats = calculate_stats(durations)
-            if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns:
+            # stop if either
+            # a) relative error dips below 0.1%
+            # b) we exceed the total time limit for benchmarking the kernel
+            # c) we exceed 2 minutes of total wallclock time.
+            if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9:
                break

    return calculate_stats(durations)


-def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float):
+def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int,
+                         max_time_ns: float):
    """
    For a particular test case, check correctness (if applicable) and grab runtime results.

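The timing change above swaps host-side perf_counter_ns for CUDA events, which measure elapsed time on the GPU stream itself; Event.elapsed_time returns milliseconds, hence the * 1e6 conversion to nanoseconds. Stripped of the harness, the pattern looks roughly like this (my_kernel is a placeholder for any GPU callable, not a name from the diff):

import torch

def time_kernel_ns(my_kernel, *args):
    # Sketch of the event-based timing used above; assumes a CUDA device is available.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()              # enqueued on the current stream
    out = my_kernel(*args)
    end_event.record()
    torch.cuda.synchronize()          # wait for both events to complete
    return start_event.elapsed_time(end_event) * 1e6, out  # ms -> ns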
@@ -359,7 +372,7 @@ def main():
        else:
            passed = False
            logger.log(f"benchmark.{i}.status", "fail")
-            logger.log(f"benchmark.{i}.error", str(result))  # TODO: Make sure result implements __str__?
+            logger.log(f"benchmark.{i}.error", str(result))  # TODO: Make sure result implements __str__?
            break

    logger.log("check", "pass" if passed else "fail")