
Commit 206b93c

Add inductor_benchmarker as latency measurement option (#333)
1 parent 98945c7 commit 206b93c

File tree

3 files changed (+93, -20 lines)

tritonbench/components/do_bench/run.py

Lines changed: 60 additions & 9 deletions
@@ -5,6 +5,7 @@
 
 import torch
 import triton
+from torch._inductor.runtime.benchmarking import benchmarker
 
 NS_TO_MS = 1e-6
 
@@ -125,6 +126,41 @@ def _summarize_statistics(times, quantiles, return_mode):
         return getattr(torch, return_mode)(times).item()
 
 
+def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
+    """Measure latency using the inductor benchmarker.
+
+    Args:
+        warmup: Target warmup time in milliseconds (matches triton.testing.do_bench)
+        rep: Target total measurement time in milliseconds (matches triton.testing.do_bench)
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        List of measured times in milliseconds.
+    """
+    # First, estimate the runtime with a short calibration run
+    estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
+
+    # Derive the number of iterations from the target rep time,
+    # similar to how triton.testing.do_bench computes its iteration count
+    if estimate_ms == 0:
+        n_repeat = 1000  # Default if the function is very fast
+    else:
+        n_repeat = max(1, int(rep / estimate_ms))
+
+    # Collect multiple measurements, like triton.testing.do_bench with return_mode="all"
+    times_ms = []
+    for _ in range(n_repeat):
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+
+        # Measure only the function execution time
+        ms_time = benchmarker.benchmark_gpu(fn)
+        times_ms.append(ms_time)
+
+    return times_ms
+
+
 def _do_bench_cpu(
     fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
 ):
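For reference, `benchmarker.benchmark_gpu` takes a zero-argument callable and returns a latency in milliseconds, which is why `_do_bench_inductor` can use it both for the calibration run (with `estimation_iters`/`benchmark_iters`) and for each individual sample. A minimal standalone sketch, assuming a CUDA-enabled PyTorch build with inductor available (the matmul workload is illustrative):

```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker

# Illustrative workload; any zero-argument callable launching GPU work fits.
x = torch.randn(2048, 2048, device="cuda")
fn = lambda: torch.mm(x, x)

# Calibration call, mirroring the first call _do_bench_inductor makes.
estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)

# Single sample, as taken inside the measurement loop.
sample_ms = benchmarker.benchmark_gpu(fn)
print(f"estimate: {estimate_ms:.4f} ms, single sample: {sample_ms:.4f} ms")
```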
@@ -174,8 +210,13 @@ def do_bench_wrapper(
     device: str = "cuda",
     use_cuda_graphs: bool = False,
     bypass_fail: bool = False,
+    latency_measure_mode: str = "triton_do_bench",
 ) -> Optional[Latency]:
-    """Wrapper to triton's do_bench to gain latency."""
+    """Wrapper around triton's do_bench to measure latency.
+
+    Args:
+        latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker"
+    """
     try:
         if device == "cpu":
             return Latency(
@@ -198,15 +239,25 @@ def do_bench_wrapper(
                 )
             )
         else:
-            return Latency(
-                times=triton.testing.do_bench(
-                    fn,
-                    warmup=warmup,
-                    rep=rep,
-                    return_mode="all",
-                    grad_to_none=grad_to_none,
+            if latency_measure_mode == "inductor_benchmarker":
+                return Latency(
+                    times=_do_bench_inductor(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        grad_to_none=grad_to_none,
+                    )
+                )
+            else:  # default to triton do_bench
+                return Latency(
+                    times=triton.testing.do_bench(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        return_mode="all",
+                        grad_to_none=grad_to_none,
+                    )
                 )
-            )
     except Exception as e:
         if not bypass_fail:
             raise e
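With this dispatch in place, callers opt into the new backend by passing `latency_measure_mode="inductor_benchmarker"`; the warmup/rep semantics and the `Latency` result (holding all measured times) are unchanged. A hedged usage sketch, with the leading parameters of `do_bench_wrapper` inferred from how the wrapper forwards them (the elementwise add is illustrative):

```python
import torch
from tritonbench.components.do_bench.run import do_bench_wrapper

x = torch.randn(1024, 1024, device="cuda")

latency = do_bench_wrapper(
    lambda: x + x,   # kernel under test (illustrative)
    warmup=25,       # target warmup in ms, as in triton.testing.do_bench
    rep=100,         # target total measurement time in ms
    grad_to_none=None,
    device="cuda",
    latency_measure_mode="inductor_benchmarker",
)
```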

tritonbench/utils/parser.py

Lines changed: 8 additions & 1 deletion
@@ -1,8 +1,9 @@
 import argparse
 
-from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
 from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP
 
+from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
+
 
 def get_parser(args=None):
     parser = argparse.ArgumentParser(allow_abbrev=False)
@@ -185,6 +186,12 @@ def get_parser(args=None):
     parser.add_argument(
         "--cudagraph", action="store_true", help="Benchmark with CUDA graph."
     )
+    parser.add_argument(
+        "--latency-measure-mode",
+        default="triton_do_bench",
+        choices=["triton_do_bench", "inductor_benchmarker"],
+        help="Method to measure latency: triton_do_bench (default) or inductor_benchmarker.",
+    )
     parser.add_argument(
         "--isolate",
         action="store_true",
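On the command line this surfaces as `--latency-measure-mode`, with argparse validating the value against the two choices. A small sketch of reading it back, assuming no other flags are required (`parse_known_args` is how tritonbench's own `override_args` parses, letting operator-specific flags pass through):

```python
from tritonbench.utils.parser import get_parser

parser = get_parser()
args, extra = parser.parse_known_args(
    ["--latency-measure-mode", "inductor_benchmarker"]
)
# argparse maps the dashed flag to an underscored attribute
assert args.latency_measure_mode == "inductor_benchmarker"
```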

tritonbench/utils/triton_op.py

Lines changed: 25 additions & 10 deletions
@@ -36,7 +36,12 @@
 )
 from tritonbench.components.export import export_data
 
-from tritonbench.utils.constants import (DEFAULT_WARMUP,DEFAULT_REP,DEFAULT_QUANTILES,DEFAULT_SLEEP)
+from tritonbench.utils.constants import (
+    DEFAULT_QUANTILES,
+    DEFAULT_REP,
+    DEFAULT_SLEEP,
+    DEFAULT_WARMUP,
+)
 from tritonbench.utils.env_utils import (
     apply_precision,
     is_fbcode,
@@ -80,6 +85,7 @@ class BenchmarkOperatorBackend:
     # ci = False implies enabled = False
     ci: bool = True
 
+
 REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
 REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
 OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
@@ -588,7 +594,6 @@ def register_benchmark(
     label: Optional[str] = None,
 ):
     def decorator(function):
-
         op_name = (
             operator_name
             if operator_name
@@ -666,6 +671,7 @@ def _has_and_true(attr):
     if _has_and_true("fwd_no_grad"):
         tb_args.mode = "fwd_no_grad"
 
+
 def override_args(args_to_override):
     parser = get_parser()
     tb_args, extra_args = parser.parse_known_args(args_to_override)
@@ -703,7 +709,9 @@ def __init__(
         if extra_args and not tb_args:
            tb_args, extra_args = override_args(extra_args)
         elif not tb_args:
-            raise ValueError('no args selected. Either pass in argparse namespace or give list override')
+            raise ValueError(
+                "no args selected. Either pass in argparse namespace or give list override"
+            )
 
         if tb_args.benchmark_name:
             self.name = tb_args.benchmark_name
@@ -819,20 +827,28 @@ def fwd_no_grad_fn():
 
         setattr(fwd_no_grad_fn, "_name", bm_func_name)
         return fwd_no_grad_fn
-
+
     def set_input_iter(self, input_iter: Callable):
         def input_decorator(input_iter):
             def input_callable(self):
                 return input_iter()
+
             return input_callable
+
         self.get_input_iter = input_decorator(input_iter)
-        self.get_input_iter = input_decorator(input_iter).__get__(self, BenchmarkOperator)
+        self.get_input_iter = input_decorator(input_iter).__get__(
+            self, BenchmarkOperator
+        )
         self.input_iter = input_iter
         self._available_num_inputs = sum(1 for _ in self.get_input_iter())
         self._num_inputs = self._available_num_inputs - self._input_id
-
+
     def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
-        decorator_kwargs = {"operator_name":self.name,"func_name":bm_func_name,"enabled":True}
+        decorator_kwargs = {
+            "operator_name": self.name,
+            "func_name": bm_func_name,
+            "enabled": True,
+        }
         decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
         bound_method = types.MethodType(decorated_func, self)
         setattr(self, bm_func_name or bm_callable.__name__, bound_method)
@@ -989,9 +1005,7 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
 
     def get_input_iter(self) -> Generator:
         """Return the dynamic input iterator for the model."""
-        logger.warning(
-            "Each operator must implement its own input iterator."
-        )
+        logger.warning("Each operator must implement its own input iterator.")
         return []
 
     def get_grad_to_none(self, args):
@@ -1268,6 +1282,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
             device=self.device,
             use_cuda_graphs=self.use_cuda_graphs,
             bypass_fail=self.tb_args.bypass_fail,
+            latency_measure_mode=self.tb_args.latency_measure_mode,
         )
         if {
             "gpu_peak_mem",
