8
8
from typing import Any , Dict , List
9
9
from . import language as tl
10
10
from . import runtime
11
- import time
12
- import logging
13
-
14
-
15
- @functools .cache
16
- def _support_elapsed_time ():
17
- import torch
18
- import triton
19
-
20
- support = True
21
- message_unsupported = "Wall time is used instead of elapsed_time (not supported). \
22
- The timing measurements could be innacurate."
23
-
24
- message_bug = "Wall time is used instead of elapsed_time because of the bug ('negative timings'). \
25
- Should be fixed in DLE 2025.1. The timing measurements could be innacurate."
26
-
27
- # FIXME: 5 iterations to detect negative timings bug; 1 should be enough for DLE 2025.1
28
- for _ in range (5 ):
29
- e1 = torch .xpu .Event (enable_timing = True )
30
- e1 .record ()
31
-
32
- e2 = torch .xpu .Event (enable_timing = True )
33
- e2 .record ()
34
-
35
- try :
36
- # FIXME: to avoid negative timings before DLE 2025.1;
37
- # this workaround doesn't work for BMG.
38
- triton .runtime .driver .active .utils .wait ()
39
- if e1 .elapsed_time (e2 ) <= 0 :
40
- logging .warning (message_bug )
41
- support = False
42
- break
43
- except Exception :
44
- logging .warning (message_unsupported )
45
- support = False
46
- break
47
-
48
- return support
49
-
50
-
51
- class WallEvent ():
52
-
53
- def __init__ (self , ** kwargs ):
54
- self .record ()
55
-
56
- def record (self ):
57
- self .timestamp = time .perf_counter_ns () / 1_000_000
58
-
59
- def elapsed_time (self , end ):
60
- return end .timestamp - self .timestamp
61
-
62
-
63
- def Event (** kwargs ):
64
- if _support_elapsed_time ():
65
- import torch
66
-
67
- return torch .xpu .Event (** kwargs )
68
- else :
69
- return WallEvent (** kwargs )
70
11
71
12
72
13
def nvsmi (attrs ):
@@ -202,7 +143,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
202
143
:type return_mode: str
203
144
"""
204
145
assert return_mode in ["min" , "max" , "mean" , "median" , "all" ]
205
- import triton
206
146
207
147
di = runtime .driver .active .get_device_interface ()
208
148
@@ -212,30 +152,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
212
152
cache = runtime .driver .active .get_empty_cache_for_benchmark ()
213
153
214
154
# Estimate the runtime of the function
215
- start_event = Event (enable_timing = True )
216
- end_event = Event (enable_timing = True )
217
- USE_WALL_TIME = isinstance (start_event , WallEvent )
218
-
155
+ start_event = di .Event (enable_timing = True )
156
+ end_event = di .Event (enable_timing = True )
219
157
start_event .record ()
220
158
for _ in range (5 ):
221
159
runtime .driver .active .clear_cache (cache )
222
160
fn ()
223
- if USE_WALL_TIME :
224
- di .synchronize ()
225
161
end_event .record ()
226
- if not USE_WALL_TIME :
227
- di .synchronize ()
228
-
229
- # FIXME: to avoid negative timings before DLE 2025.1;
230
- # this workaround doesn't work for BMG.
231
- triton .runtime .driver .active .utils .wait ()
162
+ di .synchronize ()
232
163
estimate_ms = start_event .elapsed_time (end_event ) / 5
233
164
234
165
# compute number of warmup and repeat
235
166
n_warmup = max (1 , int (warmup / estimate_ms ))
236
167
n_repeat = max (1 , int (rep / estimate_ms ))
237
- start_event = [Event (enable_timing = True ) for i in range (n_repeat )]
238
- end_event = [Event (enable_timing = True ) for i in range (n_repeat )]
168
+ start_event = [di . Event (enable_timing = True ) for i in range (n_repeat )]
169
+ end_event = [di . Event (enable_timing = True ) for i in range (n_repeat )]
239
170
# Warm-up
240
171
for _ in range (n_warmup ):
241
172
fn ()
@@ -249,21 +180,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
249
180
x .grad = None
250
181
# we clear the L2 cache before each run
251
182
runtime .driver .active .clear_cache (cache )
252
- if USE_WALL_TIME :
253
- di .synchronize ()
254
183
# record time of `fn`
255
184
start_event [i ].record ()
256
185
fn ()
257
- if USE_WALL_TIME :
258
- di .synchronize ()
259
186
end_event [i ].record ()
260
187
# Record clocks
261
- if not USE_WALL_TIME :
262
- di .synchronize ()
263
-
264
- # FIXME: to avoid negative timings before DLE 2025.1;
265
- # this workaround doesn't work for BMG.
266
- triton .runtime .driver .active .utils .wait ()
188
+ di .synchronize ()
267
189
times = [s .elapsed_time (e ) for s , e in zip (start_event , end_event )]
268
190
return _summarize_statistics (times , quantiles , return_mode )
269
191
0 commit comments