88from typing import Any , Dict , List
99from . import language as tl
1010from . import runtime
11- import time
12- import logging
13-
14-
@functools.cache
def _support_elapsed_time():
    """Probe whether ``torch.xpu.Event.elapsed_time`` yields usable timings.

    Records up to five pairs of back-to-back XPU events, waits on the active
    driver, and reads the elapsed time between each pair. A returned result is
    memoized by ``functools.cache``, so later calls reuse it.

    Returns:
        bool: ``True`` if every probe reported a strictly positive elapsed
        time; ``False`` (after logging a warning) if a probe reported a
        non-positive value or raised an exception.
    """
    import torch
    import triton

    # Implicit literal concatenation keeps source indentation out of the
    # logged text (a backslash continuation inside the literal would embed it).
    message_unsupported = ("Wall time is used instead of elapsed_time (not supported). "
                           "The timing measurements could be inaccurate.")

    message_bug = ("Wall time is used instead of elapsed_time because of the bug ('negative timings'). "
                   "Should be fixed in DLE 2025.1. The timing measurements could be inaccurate.")

    # FIXME: 5 iterations to detect negative timings bug; 1 should be enough for DLE 2025.1
    for _ in range(5):
        e1 = torch.xpu.Event(enable_timing=True)
        e1.record()

        e2 = torch.xpu.Event(enable_timing=True)
        e2.record()

        try:
            # FIXME: to avoid negative timings before DLE 2025.1;
            # this workaround doesn't work for BMG.
            triton.runtime.driver.active.utils.wait()
            elapsed = e1.elapsed_time(e2)
        except Exception:
            # Any failure while waiting or reading the timing is treated as
            # "elapsed_time is not supported" rather than propagated.
            logging.warning(message_unsupported)
            return False

        if elapsed <= 0:
            # e2 was recorded after e1, so a non-positive value is the
            # 'negative timings' bug.
            logging.warning(message_bug)
            return False

    return True
49-
50-
class WallEvent:
    """Event-like object that measures time with the host wall clock.

    Provides the ``record`` / ``elapsed_time`` pair used by the benchmarking
    code, with timestamps expressed in milliseconds.
    """

    # Nanoseconds per millisecond, used to scale perf_counter_ns() readings.
    _NS_PER_MS = 1_000_000

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored; a timestamp is taken
        # immediately on construction.
        self.record()

    def record(self):
        """Store the current high-resolution clock reading in milliseconds."""
        now_ns = time.perf_counter_ns()
        self.timestamp = now_ns / self._NS_PER_MS

    def elapsed_time(self, end):
        """Return the milliseconds from this event's timestamp to ``end``'s."""
        start_ms = self.timestamp
        stop_ms = end.timestamp
        return stop_ms - start_ms
61-
62-
def Event(**kwargs):
    """Create a timing event, forwarding ``kwargs`` to the chosen class.

    Returns a ``torch.xpu.Event`` when ``_support_elapsed_time()`` is true,
    otherwise a ``WallEvent``.
    """
    if not _support_elapsed_time():
        return WallEvent(**kwargs)

    # torch is imported only on the path that actually needs it.
    import torch

    return torch.xpu.Event(**kwargs)
7011
7112
7213def nvsmi (attrs ):
@@ -202,7 +143,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
202143 :type return_mode: str
203144 """
204145 assert return_mode in ["min" , "max" , "mean" , "median" , "all" ]
205- import triton
206146
207147 di = runtime .driver .active .get_device_interface ()
208148
@@ -212,30 +152,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
212152 cache = runtime .driver .active .get_empty_cache_for_benchmark ()
213153
214154 # Estimate the runtime of the function
215- start_event = Event (enable_timing = True )
216- end_event = Event (enable_timing = True )
217- USE_WALL_TIME = isinstance (start_event , WallEvent )
218-
155+ start_event = di .Event (enable_timing = True )
156+ end_event = di .Event (enable_timing = True )
219157 start_event .record ()
220158 for _ in range (5 ):
221159 runtime .driver .active .clear_cache (cache )
222160 fn ()
223- if USE_WALL_TIME :
224- di .synchronize ()
225161 end_event .record ()
226- if not USE_WALL_TIME :
227- di .synchronize ()
228-
229- # FIXME: to avoid negative timings before DLE 2025.1;
230- # this workaround doesn't work for BMG.
231- triton .runtime .driver .active .utils .wait ()
162+ di .synchronize ()
232163 estimate_ms = start_event .elapsed_time (end_event ) / 5
233164
234165 # compute number of warmup and repeat
235166 n_warmup = max (1 , int (warmup / estimate_ms ))
236167 n_repeat = max (1 , int (rep / estimate_ms ))
237- start_event = [Event (enable_timing = True ) for i in range (n_repeat )]
238- end_event = [Event (enable_timing = True ) for i in range (n_repeat )]
168+ start_event = [di . Event (enable_timing = True ) for i in range (n_repeat )]
169+ end_event = [di . Event (enable_timing = True ) for i in range (n_repeat )]
239170 # Warm-up
240171 for _ in range (n_warmup ):
241172 fn ()
@@ -249,21 +180,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
249180 x .grad = None
250181 # we clear the L2 cache before each run
251182 runtime .driver .active .clear_cache (cache )
252- if USE_WALL_TIME :
253- di .synchronize ()
254183 # record time of `fn`
255184 start_event [i ].record ()
256185 fn ()
257- if USE_WALL_TIME :
258- di .synchronize ()
259186 end_event [i ].record ()
260187 # Record clocks
261- if not USE_WALL_TIME :
262- di .synchronize ()
263-
264- # FIXME: to avoid negative timings before DLE 2025.1;
265- # this workaround doesn't work for BMG.
266- triton .runtime .driver .active .utils .wait ()
188+ di .synchronize ()
267189 times = [s .elapsed_time (e ) for s , e in zip (start_event , end_event )]
268190 return _summarize_statistics (times , quantiles , return_mode )
269191
0 commit comments