
Commit 1a758b2

Add inductor_benchmarker as latency measurement option
1 parent 98945c7 commit 1a758b2

3 files changed: 67 additions, 9 deletions

tritonbench/components/do_bench/run.py

Lines changed: 60 additions & 9 deletions
@@ -5,6 +5,7 @@
 
 import torch
 import triton
+from torch._inductor.runtime.benchmarking import benchmarker
 
 NS_TO_MS = 1e-6
 
@@ -125,6 +126,41 @@ def _summarize_statistics(times, quantiles, return_mode):
     return getattr(torch, return_mode)(times).item()
 
 
+def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
+    """Measure latency using the inductor benchmarker.
+
+    Args:
+        warmup: Target warmup time in milliseconds (unused here; kept for signature parity with triton.testing.do_bench)
+        rep: Target total measurement time in milliseconds (matches triton.testing.do_bench)
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        List of measured times in milliseconds.
+    """
+    # First, get a rough runtime estimate from a short calibration run
+    estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
+
+    # Calculate the number of measurement iterations from the target rep time,
+    # similar to how triton.testing.do_bench sizes its sample count
+    if estimate_ms == 0:
+        n_repeat = 1000  # Default if the function is extremely fast
+    else:
+        n_repeat = max(1, int(rep / estimate_ms))
+
+    # Collect multiple measurements, like triton.testing.do_bench with return_mode="all"
+    times_ms = []
+    for _ in range(n_repeat):
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+
+        # Measure only the function execution time
+        ms_time = benchmarker.benchmark_gpu(fn)
+        times_ms.append(ms_time)
+
+    return times_ms
+
+
 def _do_bench_cpu(
     fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
 ):
@@ -174,8 +210,13 @@ def do_bench_wrapper(
     device: str = "cuda",
     use_cuda_graphs: bool = False,
     bypass_fail: bool = False,
+    latency_measure_mode: str = "triton_do_bench",
 ) -> Optional[Latency]:
-    """Wrapper to triton's do_bench to gain latency."""
+    """Wrapper around triton's do_bench to measure latency.
+
+    Args:
+        latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker"
+    """
     try:
         if device == "cpu":
             return Latency(
@@ -198,15 +239,25 @@ def do_bench_wrapper(
                 )
             )
         else:
-            return Latency(
-                times=triton.testing.do_bench(
-                    fn,
-                    warmup=warmup,
-                    rep=rep,
-                    return_mode="all",
-                    grad_to_none=grad_to_none,
+            if latency_measure_mode == "inductor_benchmarker":
+                return Latency(
+                    times=_do_bench_inductor(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        grad_to_none=grad_to_none,
+                    )
+                )
+            else:  # default to triton do_bench
+                return Latency(
+                    times=triton.testing.do_bench(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        return_mode="all",
+                        grad_to_none=grad_to_none,
+                    )
                 )
-            )
     except Exception as e:
         if not bypass_fail:
             raise e
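
The new path mirrors triton.testing.do_bench's sampling strategy: one short calibration call sizes the sample count (for example, rep=100 with a ~0.05 ms estimate yields roughly 2000 samples), and each sample is then an independent benchmarker.benchmark_gpu call. A minimal usage sketch of the new mode follows; the matmul workload and the argument values are illustrative assumptions, not part of this commit:

import torch

from tritonbench.components.do_bench.run import do_bench_wrapper

# Hypothetical workload: a 1024x1024 matmul on CUDA.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Route latency collection through the inductor benchmarker instead of
# triton.testing.do_bench; all other arguments keep their usual meaning.
latency = do_bench_wrapper(
    lambda: torch.mm(a, b),
    warmup=25,
    rep=100,
    grad_to_none=None,
    device="cuda",
    latency_measure_mode="inductor_benchmarker",
)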

tritonbench/utils/parser.py

Lines changed: 6 additions & 0 deletions
@@ -185,6 +185,12 @@ def get_parser(args=None):
     parser.add_argument(
         "--cudagraph", action="store_true", help="Benchmark with CUDA graph."
     )
+    parser.add_argument(
+        "--latency-measure-mode",
+        default="triton_do_bench",
+        choices=["triton_do_bench", "inductor_benchmarker"],
+        help="Method to measure latency: triton_do_bench (default) or inductor_benchmarker.",
+    )
     parser.add_argument(
         "--isolate",
         action="store_true",
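
Because the flag is an ordinary argparse option with choices, its behavior is easy to sanity-check. A small sketch, assuming get_parser() returns the configured argparse.ArgumentParser (as its usage above suggests) and using parse_known_args to sidestep unrelated options:

from tritonbench.utils.parser import get_parser

parser = get_parser()

# Explicitly selecting the new mode.
args, _ = parser.parse_known_args(["--latency-measure-mode", "inductor_benchmarker"])
assert args.latency_measure_mode == "inductor_benchmarker"

# Omitting the flag preserves the existing triton do_bench behavior.
args, _ = parser.parse_known_args([])
assert args.latency_measure_mode == "triton_do_bench"

Any value outside the two choices fails fast with an argparse error rather than silently falling back to the default.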

tritonbench/utils/triton_op.py

Lines changed: 1 addition & 0 deletions
@@ -1268,6 +1268,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
                 device=self.device,
                 use_cuda_graphs=self.use_cuda_graphs,
                 bypass_fail=self.tb_args.bypass_fail,
+                latency_measure_mode=self.tb_args.latency_measure_mode,
             )
             if {
                 "gpu_peak_mem",
