Commit a696371

yf225 and can-gaa-hou authored and committed

Make Inductor benchmarker more compatible with Triton do_bench (pytorch#160921)
Common benchmark suites like TritonBench use `triton.testing.do_bench` for kernel timing measurement, which is not always fair for all backends. For example, it includes torch.compile's Dynamo invocation overhead and hence doesn't reflect real-world model use cases, where Dynamo overhead is usually hidden. I also opened a PR to use this timing measurement function on the TritonBench side: meta-pytorch/tritonbench#333. But regardless of whether that PR lands, I think we should enhance Inductor's benchmark_gpu to match do_bench's features, to make it easier for people to migrate.

Pull Request resolved: pytorch#160921
Approved by: https://github.com/BoyuanFeng
1 parent 7439a54 commit a696371
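
For context, a rough sketch of what the compatibility looks like at a call site. This is not part of the commit; it assumes a CUDA device, an installed Triton, and that the module-level `benchmarker` from `torch._inductor.runtime.benchmarking` resolves to `InductorBenchmarker`. The tensor shape and toy kernel are illustrative only.

import torch
from triton.testing import do_bench
from torch._inductor.runtime.benchmarking import benchmarker

x = torch.randn(4096, 4096, device="cuda", requires_grad=True)

def step() -> None:
    # Toy workload that accumulates gradients into `x`.
    x.sin().sum().backward()

# Triton-style timing: clear the listed gradients before each iteration
# and report the minimum over the measured iterations.
triton_ms = do_bench(step, grad_to_none=[x], return_mode="min")

# With this change, Inductor's benchmark_gpu accepts the same knobs,
# so TritonBench-style harnesses can switch without changing semantics.
inductor_ms = benchmarker.benchmark_gpu(step, grad_to_none=[x], return_mode="min")

print(f"do_bench: {triton_ms:.4f} ms, benchmark_gpu: {inductor_ms:.4f} ms")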

File tree

1 file changed (+37, -9 lines)

torch/_inductor/runtime/benchmarking.py

Lines changed: 37 additions & 9 deletions

@@ -3,7 +3,7 @@
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
 
 import torch
@@ -173,7 +173,7 @@ def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> fl
         return self.triton_do_bench(_callable, **kwargs, return_mode="median")
 
 
-class InductorBenchmarker(TritonBenchmarker):
+class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
     @cached_property
     def L2_cache_size(self: Self) -> int:
         """Get the L2 cache size, in bytes, of the current device."""
@@ -205,15 +205,17 @@ def get_event_pairs_min_timing(
         )
 
     @time_and_count
-    def benchmark_gpu(
+    def benchmark_gpu(  # type: ignore[override]
         self: Self,
         _callable: Callable[[], Any],
         estimation_iters: int = 5,
         memory_warmup_iters: int = 100,
         benchmark_iters: int = 100,
         max_benchmark_duration: int = 25,
+        return_mode: str = "min",
+        grad_to_none: Optional[list[torch.Tensor]] = None,
         **kwargs: Any,
-    ) -> float:
+    ) -> Union[float, list[float]]:
         """Benchmark a GPU callable using a custom benchmarking implementation.
 
         Arguments:
@@ -231,10 +233,15 @@ def benchmark_gpu(
         of `memory_warmup_iters` and `benchmark_iters`, along with the estimated
         runtime of `_callable` and various other factors, and we then shrink
         `benchmark_iters` to fit in the allotted maximum duration.
+        - return_mode: Return mode for benchmark results. Options are "min" (default),
+        "all" (returns all measurements).
+        - grad_to_none: Optionally, a list of tensors whose gradients should be cleared
+        before each benchmark iteration.
         - **kwargs: Additional kwargs that may be passed to the fallback.
 
         Returns:
-        - The minimum runtime of `_callable`, in milliseconds.
+        - If return_mode="min": The minimum runtime of `_callable`, in milliseconds.
+        - If return_mode="all": List of all runtime measurements, in milliseconds.
         """
         # we don't want any outside errors propagating into benchmarking
         torch.cuda.synchronize()
@@ -250,6 +257,10 @@ def benchmark_gpu(
         # estimate the runtime of `_callable`
         event_pairs = self.get_event_pairs(estimation_iters)
         for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
             buffer.zero_()
             start_event.record()
             _callable()
@@ -269,20 +280,37 @@ def benchmark_gpu(
         # benchmark `_callable`
         event_pairs = self.get_event_pairs(benchmark_iters)
         for start_event, end_event in event_pairs:
+            # Clear gradients before timing (matches triton.testing.do_bench)
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
             buffer.zero_()
             start_event.record()
             _callable()
             end_event.record()
         torch.cuda.synchronize()
-        benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
 
         # explicitly delete the buffer, sometimes helps memory
         # footprint metrics in OSS Inductor performance benchmarks
         del buffer
 
-        # return the minimum of `estimated_timing` and `benchmarked_timing`,
-        # we just want the minimum timing overall so we might as well check both
-        return min(estimated_timing, benchmarked_timing)
+        # Return based on the requested mode
+        if return_mode == "all":
+            # Get all timings from event pairs
+            all_timings = [
+                start_event.elapsed_time(end_event)
+                for start_event, end_event in event_pairs
+            ]
+            return all_timings
+        elif return_mode == "min":
+            benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
+            # return the minimum of `estimated_timing` and `benchmarked_timing`,
+            # we just want the minimum timing overall so we might as well check both
+            return min(estimated_timing, benchmarked_timing)
+        else:
+            raise ValueError(
+                f"Unsupported return_mode: {return_mode}. Use 'min' or 'all'."
+            )
 
 
 benchmarker = (
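
A second sketch, also not part of the commit and under the same assumptions as above, showing the new `return_mode="all"` path: it hands back the raw per-iteration CUDA-event timings so callers can apply whatever statistic they prefer.

from statistics import median

import torch
from torch._inductor.runtime.benchmarking import benchmarker

inp = torch.randn(8192, 8192, device="cuda")

# return_mode="all" returns every measured iteration, in milliseconds,
# instead of folding them into a single minimum.
samples = benchmarker.benchmark_gpu(lambda: inp.relu(), return_mode="all")
print(f"{len(samples)} samples, median {median(samples):.4f} ms, min {min(samples):.4f} ms")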
