14
14
import sys
15
15
import tempfile
16
16
import time
17
+ import types
17
18
18
19
from collections import defaultdict , OrderedDict
19
20
from dataclasses import asdict , dataclass , fields
35
36
)
36
37
from tritonbench .components .export import export_data
37
38
39
+ from tritonbench .utils .constants import (DEFAULT_WARMUP ,DEFAULT_REP ,DEFAULT_QUANTILES ,DEFAULT_SLEEP )
38
40
from tritonbench .utils .env_utils import (
39
41
apply_precision ,
40
42
is_fbcode ,
43
45
set_random_seed ,
44
46
)
45
47
from tritonbench .utils .input import input_cast
48
+ from tritonbench .utils .parser import get_parser
46
49
from tritonbench .utils .path_utils import add_cmd_parameter , remove_cmd_parameter
47
50
48
51
if is_hip ():
@@ -77,11 +80,6 @@ class BenchmarkOperatorBackend:
77
80
# ci = False implies enabled = False
78
81
ci : bool = True
79
82
80
-
81
- DEFAULT_WARMUP = 25
82
- DEFAULT_REP = 100
83
- DEFAULT_QUANTILES = [0.5 , 0.1 , 0.9 ]
84
- DEFAULT_SLEEP = 0.0
85
83
REGISTERED_BENCHMARKS : Dict [str , OrderedDict [str , BenchmarkOperatorBackend ]] = {}
86
84
REGISTERED_METRICS : defaultdict [str , List [str ]] = defaultdict (list )
87
85
OVERRIDDEN_METRICS : defaultdict [str , List [str ]] = defaultdict (list )
@@ -590,10 +588,11 @@ def register_benchmark(
590
588
label : Optional [str ] = None ,
591
589
):
592
590
def decorator (function ):
591
+
593
592
op_name = (
594
- _find_op_name_from_module_path ( function . __module__ )
595
- if not operator_name
596
- else operator_name
593
+ operator_name
594
+ if operator_name
595
+ else _find_op_name_from_module_path ( function . __module__ )
597
596
)
598
597
fn_name = function .__name__ if not func_name else func_name
599
598
backend_config = BenchmarkOperatorBackend (
@@ -667,6 +666,11 @@ def _has_and_true(attr):
667
666
if _has_and_true ("fwd_no_grad" ):
668
667
tb_args .mode = "fwd_no_grad"
669
668
669
+ def override_args (args_to_override ):
670
+ parser = get_parser ()
671
+ tb_args , extra_args = parser .parse_known_args (args_to_override )
672
+ return tb_args , extra_args
673
+
670
674
671
675
class BenchmarkOperator (metaclass = PostInitProcessor ):
672
676
mode : Mode = Mode .FWD
@@ -692,11 +696,19 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
692
696
"""
693
697
694
698
def __init__ (
695
- self , tb_args : argparse .Namespace , extra_args : Optional [List [str ]] = None
699
+ self , tb_args : argparse .Namespace = None , extra_args : Optional [List [str ]] = None
696
700
):
697
701
set_env ()
698
702
set_random_seed ()
699
- self .name = _find_op_name_from_module_path (self .__class__ .__module__ )
703
+ if extra_args and not tb_args :
704
+ tb_args , extra_args = override_args (extra_args )
705
+ elif not tb_args :
706
+ raise ValueError ('no args selected. Either pass in argparse namespace or give list override' )
707
+
708
+ if tb_args .benchmark_name :
709
+ self .name = tb_args .benchmark_name
710
+ else :
711
+ self .name = _find_op_name_from_module_path (self .__class__ .__module__ )
700
712
self ._raw_extra_args = copy .deepcopy (extra_args )
701
713
self .tb_args = tb_args
702
714
self .add_production_shapes = (
@@ -807,6 +819,24 @@ def fwd_no_grad_fn():
807
819
808
820
setattr (fwd_no_grad_fn , "_name" , bm_func_name )
809
821
return fwd_no_grad_fn
822
+
823
+ def set_input_iter (self , input_iter : Callable ):
824
+ def input_decorator (input_iter ):
825
+ def input_callable (self ):
826
+ return input_iter ()
827
+ return input_callable
828
+ self .get_input_iter = input_decorator (input_iter )
829
+ self .get_input_iter = input_decorator (input_iter ).__get__ (self , BenchmarkOperator )
830
+ self .input_iter = input_iter
831
+ self ._available_num_inputs = sum (1 for _ in self .get_input_iter ())
832
+ self ._num_inputs = self ._available_num_inputs - self ._input_id
833
+
834
+ def add_benchmark (self , bm_func_name : str , bm_callable : Callable ):
835
+ decorator_kwargs = {"operator_name" :self .name ,"func_name" :bm_func_name ,"enabled" :True }
836
+ decorated_func = register_benchmark (** decorator_kwargs )(bm_callable )
837
+ bound_method = types .MethodType (decorated_func , self )
838
+ setattr (self , bm_func_name or bm_callable .__name__ , bound_method )
839
+ REGISTERED_BENCHMARKS [bm_func_name ] = bm_callable
810
840
811
841
def run (
812
842
self ,
@@ -959,9 +989,10 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
959
989
960
990
def get_input_iter (self ) -> Generator :
961
991
"""Return the dynamic input iterator for the model."""
962
- raise NotImplementedError (
992
+ logger . warning (
963
993
"Each operator must implement its own input iterator."
964
994
)
995
+ return []
965
996
966
997
def get_grad_to_none (self , args ):
967
998
return None
0 commit comments