Skip to content

Commit 98945c7

Browse files
authored
adding dynamic benchmark loader
Differential Revision: D81133972 Pull Request resolved: #360
1 parent 87c4375 commit 98945c7

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

tritonbench/utils/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Shared default measurement settings for tritonbench benchmarks.
# Centralized here so parser.py and triton_op.py agree on one source of truth.
DEFAULT_WARMUP = 25  # warmup budget before timing (presumably milliseconds, per triton do_bench convention — TODO confirm)
DEFAULT_REP = 100  # measurement budget for timed repetitions (presumably milliseconds — TODO confirm)
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]  # latency quantiles reported: median, p10, p90
DEFAULT_SLEEP = 0.0  # seconds to pause between benchmark runs

tritonbench/utils/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22

33
from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
4-
from tritonbench.utils.triton_op import DEFAULT_REP, DEFAULT_WARMUP
4+
from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP
55

66

77
def get_parser(args=None):

tritonbench/utils/triton_op.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sys
1515
import tempfile
1616
import time
17+
import types
1718

1819
from collections import defaultdict, OrderedDict
1920
from dataclasses import asdict, dataclass, fields
@@ -35,6 +36,7 @@
3536
)
3637
from tritonbench.components.export import export_data
3738

39+
from tritonbench.utils.constants import (DEFAULT_WARMUP,DEFAULT_REP,DEFAULT_QUANTILES,DEFAULT_SLEEP)
3840
from tritonbench.utils.env_utils import (
3941
apply_precision,
4042
is_fbcode,
@@ -43,6 +45,7 @@
4345
set_random_seed,
4446
)
4547
from tritonbench.utils.input import input_cast
48+
from tritonbench.utils.parser import get_parser
4649
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter
4750

4851
if is_hip():
@@ -77,11 +80,6 @@ class BenchmarkOperatorBackend:
7780
# ci = False implies enabled = False
7881
ci: bool = True
7982

80-
81-
DEFAULT_WARMUP = 25
82-
DEFAULT_REP = 100
83-
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
84-
DEFAULT_SLEEP = 0.0
8583
REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
8684
REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
8785
OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
@@ -590,10 +588,11 @@ def register_benchmark(
590588
label: Optional[str] = None,
591589
):
592590
def decorator(function):
591+
593592
op_name = (
594-
_find_op_name_from_module_path(function.__module__)
595-
if not operator_name
596-
else operator_name
593+
operator_name
594+
if operator_name
595+
else _find_op_name_from_module_path(function.__module__)
597596
)
598597
fn_name = function.__name__ if not func_name else func_name
599598
backend_config = BenchmarkOperatorBackend(
@@ -667,6 +666,11 @@ def _has_and_true(attr):
667666
if _has_and_true("fwd_no_grad"):
668667
tb_args.mode = "fwd_no_grad"
669668

669+
def override_args(args_to_override):
    """Parse a raw argument list with the tritonbench parser.

    Args:
        args_to_override: list of command-line-style strings.

    Returns:
        (namespace, leftovers): the parsed known arguments and the list of
        arguments the parser did not recognize.
    """
    known, leftovers = get_parser().parse_known_args(args_to_override)
    return known, leftovers
673+
670674

671675
class BenchmarkOperator(metaclass=PostInitProcessor):
672676
mode: Mode = Mode.FWD
@@ -692,11 +696,19 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
692696
"""
693697

694698
def __init__(
695-
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
699+
self, tb_args: argparse.Namespace = None, extra_args: Optional[List[str]] = None
696700
):
697701
set_env()
698702
set_random_seed()
699-
self.name = _find_op_name_from_module_path(self.__class__.__module__)
703+
if extra_args and not tb_args:
704+
tb_args, extra_args = override_args(extra_args)
705+
elif not tb_args:
706+
raise ValueError('no args selected. Either pass in argparse namespace or give list override')
707+
708+
if tb_args.benchmark_name:
709+
self.name = tb_args.benchmark_name
710+
else:
711+
self.name = _find_op_name_from_module_path(self.__class__.__module__)
700712
self._raw_extra_args = copy.deepcopy(extra_args)
701713
self.tb_args = tb_args
702714
self.add_production_shapes = (
@@ -807,6 +819,24 @@ def fwd_no_grad_fn():
807819

808820
setattr(fwd_no_grad_fn, "_name", bm_func_name)
809821
return fwd_no_grad_fn
822+
823+
def set_input_iter(self, input_iter: Callable):
    """Dynamically install ``input_iter`` as this operator's input generator.

    Wraps the zero-argument callable so it matches the
    ``get_input_iter(self)`` method signature, binds it to this instance,
    and recomputes the input counts from it.

    Args:
        input_iter: zero-argument callable returning an iterable of inputs.
            NOTE(review): it is consumed once here to count inputs and will be
            iterated again by ``run`` — assumes each call yields a fresh
            iterable, TODO confirm.
    """

    def _as_method(fn):
        # Adapt the zero-arg factory to the (self) method signature.
        def input_callable(self):
            return fn()

        return input_callable

    # Bind via the descriptor protocol so self.get_input_iter() behaves like
    # a real instance method. (The previous code also performed an unbound
    # assignment first, which was dead — immediately overwritten here.)
    self.get_input_iter = _as_method(input_iter).__get__(self, BenchmarkOperator)
    self.input_iter = input_iter
    # Count the available inputs, then offset by the configured starting id.
    self._available_num_inputs = sum(1 for _ in self.get_input_iter())
    self._num_inputs = self._available_num_inputs - self._input_id
833+
834+
def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
    """Register ``bm_callable`` as a benchmark backend on this operator at runtime.

    Args:
        bm_func_name: name under which the benchmark is registered and exposed.
        bm_callable: the benchmark implementation to register.
    """
    decorator_kwargs = {"operator_name":self.name,"func_name":bm_func_name,"enabled":True}
    # Route through the standard decorator so the callable gets the usual
    # backend registration bookkeeping.
    decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
    # Bind the decorated function to this instance so it is callable as a method.
    bound_method = types.MethodType(decorated_func, self)
    # Falls back to the callable's own name if bm_func_name is empty.
    setattr(self, bm_func_name or bm_callable.__name__, bound_method)
    # NOTE(review): REGISTERED_BENCHMARKS is declared as
    # Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] keyed by operator
    # name, but this stores the raw callable keyed by function name — verify
    # this does not clobber the per-operator registry entries.
    REGISTERED_BENCHMARKS[bm_func_name] = bm_callable
810840

811841
def run(
812842
self,
@@ -959,9 +989,10 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
959989

960990
def get_input_iter(self) -> Generator:
    """Default input iterator: warns and yields no inputs.

    Operators normally override this (or install one via ``set_input_iter``);
    this base implementation only logs a warning and returns an empty
    iterable.
    """
    logger.warning("Each operator must implement its own input iterator.")
    return []
965996

966997
def get_grad_to_none(self, args):
967998
return None

0 commit comments

Comments
 (0)