Skip to content

Commit a83ce70

Browse files
committed
Add inductor_benchmarker as latency measurement option
1 parent a404ea7 commit a83ce70

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

tritonbench/components/do_bench/run.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
import triton
8+
from torch._inductor.runtime.benchmarking import benchmarker
89

910
NS_TO_MS = 1e-6
1011

@@ -125,6 +126,42 @@ def _summarize_statistics(times, quantiles, return_mode):
125126
return getattr(torch, return_mode)(times).item()
126127

127128

129+
def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
    """Measure latency using the inductor benchmarker.

    Args:
        fn: Callable to benchmark.
        warmup: Target warmup time in milliseconds. Accepted for signature
            parity with triton.testing.do_bench; this implementation does
            not consume it.
        rep: Target total measurement time in milliseconds (matches
            triton.testing.do_bench).
        grad_to_none: Tensors whose gradients are cleared before each
            measurement.

    Returns:
        List of measured times in milliseconds.
    """
    # Take one quick estimate of the per-call runtime so we can size the
    # repeat count against the target total measurement time.
    est_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)

    # Mirror triton.testing.do_bench's iteration sizing; fall back to 1000
    # iterations when the function is too fast to register an estimate.
    n_iters = 1000 if est_ms == 0 else max(1, int(rep / est_ms))

    # Gather one sample per iteration, like do_bench with return_mode='all'.
    measurements = []
    for _ in range(n_iters):
        # Clear gradients outside the timed region, as triton's do_bench does.
        if grad_to_none is not None:
            for tensor in grad_to_none:
                tensor.grad = None

        # Time only the function call itself.
        measurements.append(benchmarker.benchmark_gpu(fn))

    return measurements
128165
def _do_bench_cpu(
129166
fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
130167
):
@@ -174,8 +211,13 @@ def do_bench_wrapper(
174211
device: str = "cuda",
175212
use_cuda_graphs: bool = False,
176213
bypass_fail: bool = False,
214+
latency_measure_mode: str = "triton_do_bench",
177215
) -> Optional[Latency]:
178-
"""Wrapper to triton's do_bench to gain latency."""
216+
"""Wrapper to triton's do_bench to gain latency.
217+
218+
Args:
219+
latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker"
220+
"""
179221
try:
180222
if device == "cpu":
181223
return Latency(
@@ -198,15 +240,25 @@ def do_bench_wrapper(
198240
)
199241
)
200242
else:
201-
return Latency(
202-
times=triton.testing.do_bench(
203-
fn,
204-
warmup=warmup,
205-
rep=rep,
206-
return_mode="all",
207-
grad_to_none=grad_to_none,
243+
if latency_measure_mode == "inductor_benchmarker":
244+
return Latency(
245+
times=_do_bench_inductor(
246+
fn,
247+
warmup=warmup,
248+
rep=rep,
249+
grad_to_none=grad_to_none,
250+
)
251+
)
252+
else: # default to triton do_bench
253+
return Latency(
254+
times=triton.testing.do_bench(
255+
fn,
256+
warmup=warmup,
257+
rep=rep,
258+
return_mode="all",
259+
grad_to_none=grad_to_none,
260+
)
208261
)
209-
)
210262
except Exception as e:
211263
if not bypass_fail:
212264
raise e

tritonbench/utils/parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ def get_parser(args=None):
179179
parser.add_argument(
180180
"--cudagraph", action="store_true", help="Benchmark with CUDA graph."
181181
)
182+
parser.add_argument(
183+
"--latency-measure-mode",
184+
default="triton_do_bench",
185+
choices=["triton_do_bench", "inductor_benchmarker"],
186+
help="Method to measure latency: triton_do_bench (default) or inductor_benchmarker.",
187+
)
182188
parser.add_argument(
183189
"--isolate",
184190
action="store_true",

tritonbench/utils/triton_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
12291229
device=self.device,
12301230
use_cuda_graphs=self.use_cuda_graphs,
12311231
bypass_fail=self.tb_args.bypass_fail,
1232+
latency_measure_mode=self.tb_args.latency_measure_mode,
12321233
)
12331234
if {
12341235
"gpu_peak_mem",

0 commit comments

Comments
 (0)