diff --git a/tritonbench/utils/constants.py b/tritonbench/utils/constants.py
new file mode 100644
index 000000000..e252aee3f
--- /dev/null
+++ b/tritonbench/utils/constants.py
@@ -0,0 +1,4 @@
+DEFAULT_WARMUP = 25
+DEFAULT_REP = 100
+DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
+DEFAULT_SLEEP = 0.0
diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py
index aa7a60ae7..20bf89865 100644
--- a/tritonbench/utils/parser.py
+++ b/tritonbench/utils/parser.py
@@ -1,7 +1,7 @@
 import argparse
 
 from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
-from tritonbench.utils.triton_op import DEFAULT_REP, DEFAULT_WARMUP
+from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP
 
 
 def get_parser(args=None):
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 44167ae90..8ea2ad6d1 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -14,6 +14,7 @@
 import sys
 import tempfile
 import time
+import types
 
 from collections import defaultdict, OrderedDict
 from dataclasses import asdict, dataclass, fields
@@ -35,6 +36,12 @@
 )
 
 from tritonbench.components.export import export_data
+from tritonbench.utils.constants import (
+    DEFAULT_QUANTILES,
+    DEFAULT_REP,
+    DEFAULT_SLEEP,
+    DEFAULT_WARMUP,
+)
 from tritonbench.utils.env_utils import (
     apply_precision,
     is_fbcode,
@@ -43,6 +50,7 @@
     set_random_seed,
 )
 from tritonbench.utils.input import input_cast
+from tritonbench.utils.parser import get_parser
 from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter
 
 if is_hip():
@@ -77,11 +85,6 @@ class BenchmarkOperatorBackend:
     # ci = False implies enabled = False
     ci: bool = True
 
-
-DEFAULT_WARMUP = 25
-DEFAULT_REP = 100
-DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
-DEFAULT_SLEEP = 0.0
 REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
 REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
 OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
@@ -590,10 +593,10 @@ def register_benchmark(
     label: Optional[str] = None,
 ):
     def decorator(function):
         op_name = (
-            _find_op_name_from_module_path(function.__module__)
-            if not operator_name
-            else operator_name
+            operator_name
+            if operator_name
+            else _find_op_name_from_module_path(function.__module__)
         )
         fn_name = function.__name__ if not func_name else func_name
         backend_config = BenchmarkOperatorBackend(
@@ -667,6 +670,13 @@ def _has_and_true(attr):
         if _has_and_true("fwd_no_grad"):
             tb_args.mode = "fwd_no_grad"
 
 
+def override_args(args_to_override):
+    """Parse a list of CLI-style override strings into tritonbench args."""
+    parser = get_parser()
+    tb_args, extra_args = parser.parse_known_args(args_to_override)
+    return tb_args, extra_args
+
+
 class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
@@ -692,11 +702,22 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
     """
 
     def __init__(
-        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+        self,
+        tb_args: Optional[argparse.Namespace] = None,
+        extra_args: Optional[List[str]] = None,
     ):
         set_env()
         set_random_seed()
-        self.name = _find_op_name_from_module_path(self.__class__.__module__)
+        if extra_args and not tb_args:
+            tb_args, extra_args = override_args(extra_args)
+        elif not tb_args:
+            raise ValueError(
+                "No args given. Either pass an argparse namespace or a list of override args."
+            )
+        if getattr(tb_args, "benchmark_name", None):
+            self.name = tb_args.benchmark_name
+        else:
+            self.name = _find_op_name_from_module_path(self.__class__.__module__)
         self._raw_extra_args = copy.deepcopy(extra_args)
         self.tb_args = tb_args
         self.add_production_shapes = (
@@ -807,5 +828,44 @@ def fwd_no_grad_fn():
         setattr(fwd_no_grad_fn, "_name", bm_func_name)
         return fwd_no_grad_fn
 
+    def set_input_iter(self, input_iter: Callable):
+        """Install input_iter as this instance's input generator."""
+        def input_decorator(input_iter):
+            def input_callable(self):
+                return input_iter()
+
+            return input_callable
+
+        # Bind the wrapper as a method so it overrides get_input_iter.
+        self.get_input_iter = input_decorator(input_iter).__get__(
+            self, BenchmarkOperator
+        )
+        self.input_iter = input_iter
+        self._available_num_inputs = sum(1 for _ in self.get_input_iter())
+        self._num_inputs = self._available_num_inputs - self._input_id
+
+    def add_benchmark(
+        self,
+        bm_callable: Callable,
+        operator_name: Optional[str] = None,
+        func_name: Optional[str] = None,
+        baseline: bool = False,
+        fwd_only: bool = False,
+        label: Optional[str] = None,
+    ) -> None:
+        decorator_kwargs = {
+            "operator_name": operator_name or self.name,
+            "func_name": func_name,
+            "enabled": True,
+            "baseline": baseline,
+            "fwd_only": fwd_only,
+            "label": label,
+        }
+        # register_benchmark records the backend config in REGISTERED_BENCHMARKS,
+        # so no manual registry update is needed here.
+        decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
+        bound_method = types.MethodType(decorated_func, self)
+        setattr(self, func_name or bm_callable.__name__, bound_method)
+
     def run(
         self,
@@ -959,9 +1019,10 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
 
     def get_input_iter(self) -> Generator:
         """Return the dynamic input iterator for the model."""
-        raise NotImplementedError(
+        logger.warning(
             "Each operator must implement its own input iterator."
         )
+        return []
 
     def get_grad_to_none(self, args):
         return None
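Usage sketch (not part of the patch): the new override path lets an operator be
constructed without hand-building an argparse.Namespace. A minimal sketch, assuming
tritonbench's existing load_opbench_by_name helper and a parser that defines
--benchmark-name; the operator name and flag values below are placeholders.

    from tritonbench.operators import load_opbench_by_name

    # Look up an existing operator class by name ("vector_add" is a placeholder).
    Operator = load_opbench_by_name("vector_add")
    # Passing only extra_args routes through override_args(), which runs
    # get_parser().parse_known_args() over the list.
    op = Operator(extra_args=["--benchmark-name", "my_vector_add"])
    op.run()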
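The two new instance-level hooks can then register inputs and backends at runtime
instead of at import time. A sketch under the same assumptions; the input shapes and
the doubling backend are made up.

    import torch

    def my_inputs():
        # Each yielded tuple is unpacked as the arguments of a benchmark backend.
        for size in (1024, 2048):
            yield (torch.randn(size, device="cuda"),)

    def my_backend(self, x):
        # Backends follow the register_benchmark contract: take the inputs,
        # return the callable to be timed.
        return lambda: x * 2

    op.set_input_iter(my_inputs)   # replaces get_input_iter and recounts inputs
    op.add_benchmark(my_backend, func_name="my_backend", label="x2")
    op.run()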