diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
index dc0a7090..2b1bae30 100644
--- a/tritonbench/components/do_bench/run.py
+++ b/tritonbench/components/do_bench/run.py
@@ -5,6 +5,7 @@
 
 import torch
 import triton
+from torch._inductor.runtime.benchmarking import benchmarker
 
 NS_TO_MS = 1e-6
 
@@ -125,6 +126,41 @@ def _summarize_statistics(times, quantiles, return_mode):
     return getattr(torch, return_mode)(times).item()
 
 
+def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
+    """Measure latency using the inductor benchmarker.
+
+    Args:
+        warmup: Unused; kept to match triton.testing.do_bench (the inductor benchmarker does its own warmup)
+        rep: Target total measurement time in milliseconds (matches triton.testing.do_bench)
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        List of measured times in milliseconds.
+    """
+    # First, get a quick runtime estimate using a small number of iterations
+    estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
+
+    # Calculate the number of iterations from the target rep time,
+    # similar to how triton.testing.do_bench calculates iterations
+    if estimate_ms == 0:
+        n_repeat = 1000  # Fallback if the function is too fast to measure
+    else:
+        n_repeat = max(1, int(rep / estimate_ms))
+
+    # Collect multiple measurements like triton.testing.do_bench with return_mode="all"
+    times_ms = []
+    for _ in range(n_repeat):
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+
+        # Measure only the function execution time
+        ms_time = benchmarker.benchmark_gpu(fn)
+        times_ms.append(ms_time)
+
+    return times_ms
+
+
 def _do_bench_cpu(
     fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
 ):
@@ -174,8 +210,13 @@ def do_bench_wrapper(
     device: str = "cuda",
     use_cuda_graphs: bool = False,
     bypass_fail: bool = False,
+    latency_measure_mode: str = "triton_do_bench",
 ) -> Optional[Latency]:
-    """Wrapper to triton's do_bench to gain latency."""
+    """Wrapper around triton's do_bench to measure latency.
+
+    Args:
+        latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker"
+    """
     try:
         if device == "cpu":
             return Latency(
@@ -198,15 +239,25 @@ def do_bench_wrapper(
                 )
             )
         else:
-            return Latency(
-                times=triton.testing.do_bench(
-                    fn,
-                    warmup=warmup,
-                    rep=rep,
-                    return_mode="all",
-                    grad_to_none=grad_to_none,
+            if latency_measure_mode == "inductor_benchmarker":
+                return Latency(
+                    times=_do_bench_inductor(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        grad_to_none=grad_to_none,
+                    )
+                )
+            else:  # default to triton do_bench
+                return Latency(
+                    times=triton.testing.do_bench(
+                        fn,
+                        warmup=warmup,
+                        rep=rep,
+                        return_mode="all",
+                        grad_to_none=grad_to_none,
+                    )
                 )
-            )
     except Exception as e:
         if not bypass_fail:
             raise e
diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py
index 20bf8986..08f4e609 100644
--- a/tritonbench/utils/parser.py
+++ b/tritonbench/utils/parser.py
@@ -1,8 +1,9 @@
 import argparse
 
-from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
 from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP
+from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
+
 
 def get_parser(args=None):
     parser = argparse.ArgumentParser(allow_abbrev=False)
 
@@ -185,6 +186,12 @@ def get_parser(args=None):
     parser.add_argument(
         "--cudagraph", action="store_true", help="Benchmark with CUDA graph."
     )
+    parser.add_argument(
+        "--latency-measure-mode",
+        default="triton_do_bench",
+        choices=["triton_do_bench", "inductor_benchmarker"],
+        help="Method to measure latency: triton_do_bench (default) or inductor_benchmarker.",
+    )
     parser.add_argument(
         "--isolate",
         action="store_true",
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index dd06cb9c..88b7d5f2 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -36,7 +36,12 @@
 )
 
 from tritonbench.components.export import export_data
-from tritonbench.utils.constants import (DEFAULT_WARMUP,DEFAULT_REP,DEFAULT_QUANTILES,DEFAULT_SLEEP)
+from tritonbench.utils.constants import (
+    DEFAULT_QUANTILES,
+    DEFAULT_REP,
+    DEFAULT_SLEEP,
+    DEFAULT_WARMUP,
+)
 from tritonbench.utils.env_utils import (
     apply_precision,
     is_fbcode,
@@ -80,6 +85,7 @@ class BenchmarkOperatorBackend:
     # ci = False implies enabled = False
     ci: bool = True
 
+
 REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
 REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
 OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
@@ -588,7 +594,6 @@ def register_benchmark(
     label: Optional[str] = None,
 ):
     def decorator(function):
-
         op_name = (
             operator_name
             if operator_name
@@ -666,6 +671,7 @@ def _has_and_true(attr):
     if _has_and_true("fwd_no_grad"):
         tb_args.mode = "fwd_no_grad"
 
+
 def override_args(args_to_override):
     parser = get_parser()
     tb_args, extra_args = parser.parse_known_args(args_to_override)
@@ -703,7 +709,9 @@ def __init__(
         if extra_args and not tb_args:
             tb_args, extra_args = override_args(extra_args)
         elif not tb_args:
-            raise ValueError('no args selected. Either pass in argparse namespace or give list override')
+            raise ValueError(
+                "no args selected. Either pass in argparse namespace or give list override"
+            )
 
         if tb_args.benchmark_name:
             self.name = tb_args.benchmark_name
@@ -819,20 +827,28 @@ def fwd_no_grad_fn():
 
             setattr(fwd_no_grad_fn, "_name", bm_func_name)
             return fwd_no_grad_fn
-
+
     def set_input_iter(self, input_iter: Callable):
         def input_decorator(input_iter):
            def input_callable(self):
                return input_iter()
+
            return input_callable
+
-        self.get_input_iter = input_decorator(input_iter)
+        self.get_input_iter = input_decorator(input_iter).__get__(
+            self, BenchmarkOperator
+        )
         self.input_iter = input_iter
         self._available_num_inputs = sum(1 for _ in self.get_input_iter())
         self._num_inputs = self._available_num_inputs - self._input_id
-
+
     def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
-        decorator_kwargs = {"operator_name":self.name,"func_name":bm_func_name,"enabled":True}
+        decorator_kwargs = {
+            "operator_name": self.name,
+            "func_name": bm_func_name,
+            "enabled": True,
+        }
         decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
         bound_method = types.MethodType(decorated_func, self)
         setattr(self, bm_func_name or bm_callable.__name__, bound_method)
 
@@ -989,9 +1005,7 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
 
     def get_input_iter(self) -> Generator:
         """Return the dynamic input iterator for the model."""
-        logger.warning(
-            "Each operator must implement its own input iterator."
-        )
+        logger.warning("Each operator must implement its own input iterator.")
         return []
 
     def get_grad_to_none(self, args):
@@ -1268,6 +1282,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
                 device=self.device,
                 use_cuda_graphs=self.use_cuda_graphs,
                 bypass_fail=self.tb_args.bypass_fail,
+                latency_measure_mode=self.tb_args.latency_measure_mode,
             )
             if {
                 "gpu_peak_mem",
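
Usage sketch (editor's note, not part of the patch): assuming tritonbench's usual `python run.py --op <name>` entry point, and with `gemm` as a purely illustrative operator name, the flag added in parser.py above routes do_bench_wrapper to the inductor benchmarker path:

    # hypothetical invocation; omit the flag to keep the default triton_do_bench path
    python run.py --op gemm --latency-measure-mode inductor_benchmarker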