Skip to content

Commit 0ca9cda

Browse files
Remove workaround in testing.py (#4846)
Since we are now using DLE 2025.1.1, the workaround that avoided negative timings (needed before DLE 2025.1) can be removed.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 2ed5768 commit 0ca9cda

File tree

2 files changed

+6
-88
lines changed

2 files changed

+6
-88
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020
from torch.profiler import profile, ProfilerActivity, record_function
2121

22-
import triton
2322
from triton.testing import assert_close as triton_assert_close, Benchmark, do_bench as triton_do_bench
2423

2524
from triton_kernels_benchmark import build_report
@@ -93,9 +92,6 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
9392
fn()
9493
end_event.record()
9594
synchronize()
96-
# FIXME: to avoid negative timings before DLE 2025.1;
97-
# this workaround doesn't work for BMG.
98-
triton.runtime.driver.active.utils.wait()
9995
estimate_ms = start_event.elapsed_time(end_event) / 5
10096

10197
# The cache is also maintained in `triton_do_bench` function,

python/triton/testing.py

Lines changed: 6 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -8,65 +8,6 @@
88
from typing import Any, Dict, List
99
from . import language as tl
1010
from . import runtime
11-
import time
12-
import logging
13-
14-
15-
@functools.cache
16-
def _support_elapsed_time():
17-
import torch
18-
import triton
19-
20-
support = True
21-
message_unsupported = "Wall time is used instead of elapsed_time (not supported). \
22-
The timing measurements could be innacurate."
23-
24-
message_bug = "Wall time is used instead of elapsed_time because of the bug ('negative timings'). \
25-
Should be fixed in DLE 2025.1. The timing measurements could be innacurate."
26-
27-
# FIXME: 5 iterations to detect negative timings bug; 1 should be enough for DLE 2025.1
28-
for _ in range(5):
29-
e1 = torch.xpu.Event(enable_timing=True)
30-
e1.record()
31-
32-
e2 = torch.xpu.Event(enable_timing=True)
33-
e2.record()
34-
35-
try:
36-
# FIXME: to avoid negative timings before DLE 2025.1;
37-
# this workaround doesn't work for BMG.
38-
triton.runtime.driver.active.utils.wait()
39-
if e1.elapsed_time(e2) <= 0:
40-
logging.warning(message_bug)
41-
support = False
42-
break
43-
except Exception:
44-
logging.warning(message_unsupported)
45-
support = False
46-
break
47-
48-
return support
49-
50-
51-
class WallEvent():
52-
53-
def __init__(self, **kwargs):
54-
self.record()
55-
56-
def record(self):
57-
self.timestamp = time.perf_counter_ns() / 1_000_000
58-
59-
def elapsed_time(self, end):
60-
return end.timestamp - self.timestamp
61-
62-
63-
def Event(**kwargs):
64-
if _support_elapsed_time():
65-
import torch
66-
67-
return torch.xpu.Event(**kwargs)
68-
else:
69-
return WallEvent(**kwargs)
7011

7112

7213
def nvsmi(attrs):
@@ -202,7 +143,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
202143
:type return_mode: str
203144
"""
204145
assert return_mode in ["min", "max", "mean", "median", "all"]
205-
import triton
206146

207147
di = runtime.driver.active.get_device_interface()
208148

@@ -212,30 +152,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
212152
cache = runtime.driver.active.get_empty_cache_for_benchmark()
213153

214154
# Estimate the runtime of the function
215-
start_event = Event(enable_timing=True)
216-
end_event = Event(enable_timing=True)
217-
USE_WALL_TIME = isinstance(start_event, WallEvent)
218-
155+
start_event = di.Event(enable_timing=True)
156+
end_event = di.Event(enable_timing=True)
219157
start_event.record()
220158
for _ in range(5):
221159
runtime.driver.active.clear_cache(cache)
222160
fn()
223-
if USE_WALL_TIME:
224-
di.synchronize()
225161
end_event.record()
226-
if not USE_WALL_TIME:
227-
di.synchronize()
228-
229-
# FIXME: to avoid negative timings before DLE 2025.1;
230-
# this workaround doesn't work for BMG.
231-
triton.runtime.driver.active.utils.wait()
162+
di.synchronize()
232163
estimate_ms = start_event.elapsed_time(end_event) / 5
233164

234165
# compute number of warmup and repeat
235166
n_warmup = max(1, int(warmup / estimate_ms))
236167
n_repeat = max(1, int(rep / estimate_ms))
237-
start_event = [Event(enable_timing=True) for i in range(n_repeat)]
238-
end_event = [Event(enable_timing=True) for i in range(n_repeat)]
168+
start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
169+
end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
239170
# Warm-up
240171
for _ in range(n_warmup):
241172
fn()
@@ -249,21 +180,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
249180
x.grad = None
250181
# we clear the L2 cache before each run
251182
runtime.driver.active.clear_cache(cache)
252-
if USE_WALL_TIME:
253-
di.synchronize()
254183
# record time of `fn`
255184
start_event[i].record()
256185
fn()
257-
if USE_WALL_TIME:
258-
di.synchronize()
259186
end_event[i].record()
260187
# Record clocks
261-
if not USE_WALL_TIME:
262-
di.synchronize()
263-
264-
# FIXME: to avoid negative timings before DLE 2025.1;
265-
# this workaround doesn't work for BMG.
266-
triton.runtime.driver.active.utils.wait()
188+
di.synchronize()
267189
times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
268190
return _summarize_statistics(times, quantiles, return_mode)
269191

0 commit comments

Comments
 (0)