@@ -3,7 +3,7 @@
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
 
 import torch
@@ -173,7 +173,7 @@ def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
         return self.triton_do_bench(_callable, **kwargs, return_mode="median")
 
 
-class InductorBenchmarker(TritonBenchmarker):
+class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
     @cached_property
     def L2_cache_size(self: Self) -> int:
         """Get the L2 cache size, in bytes, of the current device."""
@@ -205,15 +205,17 @@ def get_event_pairs_min_timing(
         )
 
     @time_and_count
-    def benchmark_gpu(
+    def benchmark_gpu(  # type: ignore[override]
         self: Self,
         _callable: Callable[[], Any],
         estimation_iters: int = 5,
         memory_warmup_iters: int = 100,
         benchmark_iters: int = 100,
         max_benchmark_duration: int = 25,
+        return_mode: str = "min",
+        grad_to_none: Optional[list[torch.Tensor]] = None,
         **kwargs: Any,
-    ) -> float:
+    ) -> Union[float, list[float]]:
         """Benchmark a GPU callable using a custom benchmarking implementation.
 
         Arguments:
@@ -231,10 +233,15 @@ def benchmark_gpu(
         of `memory_warmup_iters` and `benchmark_iters`, along with the estimated
         runtime of `_callable` and various other factors, and we then shrink
         `benchmark_iters` to fit in the allotted maximum duration.
+        - return_mode: Return mode for benchmark results. Options are "min"
+        (the default) and "all" (returns all measurements).
+        - grad_to_none: Optionally, a list of tensors whose gradients should be
+        cleared before each benchmark iteration.
         - **kwargs: Additional kwargs that may be passed to the fallback.
 
         Returns:
-        - The minimum runtime of `_callable`, in milliseconds.
+        - If return_mode="min": the minimum runtime of `_callable`, in milliseconds.
+        - If return_mode="all": a list of all runtime measurements, in milliseconds.
         """
         # we don't want any outside errors propagating into benchmarking
         torch.cuda.synchronize()
@@ -250,6 +257,10 @@ def benchmark_gpu(
         # estimate the runtime of `_callable`
         event_pairs = self.get_event_pairs(estimation_iters)
         for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
             buffer.zero_()
             start_event.record()
             _callable()
@@ -269,20 +280,37 @@ def benchmark_gpu(
         # benchmark `_callable`
         event_pairs = self.get_event_pairs(benchmark_iters)
         for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
             buffer.zero_()
             start_event.record()
             _callable()
             end_event.record()
         torch.cuda.synchronize()
-        benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
 
         # explicitly delete the buffer, sometimes helps memory
         # footprint metrics in OSS Inductor performance benchmarks
         del buffer
 
-        # return the minimum of `estimated_timing` and `benchmarked_timing`,
-        # we just want the minimum timing overall so we might as well check both
-        return min(estimated_timing, benchmarked_timing)
+        # Return based on the requested mode
+        if return_mode == "all":
+            # Get all timings from event pairs
+            all_timings = [
+                start_event.elapsed_time(end_event)
+                for start_event, end_event in event_pairs
+            ]
+            return all_timings
+        elif return_mode == "min":
+            benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
+            # return the minimum of `estimated_timing` and `benchmarked_timing`,
+            # we just want the minimum timing overall so we might as well check both
+            return min(estimated_timing, benchmarked_timing)
+        else:
+            raise ValueError(
+                f"Unsupported return_mode: {return_mode}. Use 'min' or 'all'."
+            )
 
 
 benchmarker = (
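
For illustration, a minimal usage sketch of the two new parameters. This is a hedged sketch, not part of the change: the import path, the direct InductorBenchmarker() construction, and the toy workload are assumptions for the example, and a CUDA device is required.

from statistics import median

import torch

# Assumed import path; in-tree callers normally go through the module-level
# `benchmarker` instance rather than constructing their own.
from torch._inductor.runtime.benchmarking import InductorBenchmarker

bench = InductorBenchmarker()

# A toy workload that accumulates gradients into `x` on every call,
# which is exactly the case `grad_to_none` is meant to handle.
x = torch.randn(1024, 1024, device="cuda", requires_grad=True)

def fn() -> None:
    (x @ x).sum().backward()

# Existing behavior: a single float, the minimum runtime in milliseconds,
# now with `x.grad` reset before each timed iteration.
min_ms = bench.benchmark_gpu(fn, grad_to_none=[x])

# New behavior: the raw timing of every benchmark iteration, e.g. for
# computing statistics other than the minimum.
all_ms = bench.benchmark_gpu(fn, return_mode="all", grad_to_none=[x])
median_ms = median(all_ms)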