diff --git a/benchmarks/tagging/run.py b/benchmarks/tagging/run.py
index a782093bd..5e330aec5 100644
--- a/benchmarks/tagging/run.py
+++ b/benchmarks/tagging/run.py
@@ -42,6 +42,7 @@ def setup_tritonbench_cwd():
 from tritonbench.operators import list_operators
 from tritonbench.utils.operator_utils import get_backends_for_operator
 from tritonbench.utils.run_utils import load_operator_by_args
+from tritonbench.utils.triton_op import REGISTERED_BACKEND_TAGS
 
 try:
     from ast_analyzer import build_backend_callees, trace_callees
@@ -112,7 +113,35 @@ def apply_name_based_heuristics(backend_name, tags_dict):
     return tags_dict
 
 
-def prevalidate_backends(backend_edges):
+def merge_decorator_tags(op_name, backend_name, tags_dict):
+    """
+    Merge tags from the @register_benchmark decorator with auto-detected tags.
+
+    Args:
+        op_name: The operator name
+        backend_name: The backend name
+        tags_dict: Dictionary with auto-detected tags, e.g., {"tags": ["pt2"]}.
+                   If None, it will be created.
+
+    Returns:
+        Updated tags_dict with the decorator tags merged in
+    """
+    if not tags_dict:
+        tags_dict = {"tags": []}
+    if "tags" not in tags_dict:
+        tags_dict["tags"] = []
+
+    # Get decorator tags if they exist
+    decorator_tags = REGISTERED_BACKEND_TAGS.get(op_name, {}).get(backend_name, [])
+    if decorator_tags:
+        # Merge decorator tags with auto-detected tags (remove duplicates)
+        all_tags = list(set(decorator_tags + tags_dict["tags"]))
+        tags_dict["tags"] = all_tags
+
+    return tags_dict
+
+
+def prevalidate_backends(backend_edges, op_name=None):
     op_with_tags = {}
     # heuristic: do not search torch.nn, torch.compile, and xformers backends
     for backend, callees in backend_edges.items():
@@ -140,6 +169,11 @@ def prevalidate_backends(backend_edges):
             op_with_tags[backend] = apply_name_based_heuristics(
                 backend, op_with_tags[backend]
             )
+            # Merge with decorator tags if available
+            if op_name:
+                op_with_tags[backend] = merge_decorator_tags(
+                    op_name, backend, op_with_tags[backend]
+                )
     return op_with_tags
 
 
@@ -164,7 +198,7 @@ def trace_op(op):
         backends=backends,
     )
     assert len(backend_edges) == len(backends)
-    op_with_tags[op] = prevalidate_backends(backend_edges)
+    op_with_tags[op] = prevalidate_backends(backend_edges, op_name=op)
     remaining_backends = [
         backend for backend in backends if backend not in op_with_tags[op]
     ]
@@ -181,6 +215,10 @@ def trace_op(op):
             op_with_tags[op][backend] = apply_name_based_heuristics(
                 backend, op_with_tags[op][backend]
            )
+            # Merge with decorator tags
+            op_with_tags[op][backend] = merge_decorator_tags(
+                op, backend, op_with_tags[op][backend]
+            )
     return op_with_tags
 
 
diff --git a/tritonbench/operators/ADD_OPERATOR.md b/tritonbench/operators/ADD_OPERATOR.md
index 1aaff99db..bd03396fd 100644
--- a/tritonbench/operators/ADD_OPERATOR.md
+++ b/tritonbench/operators/ADD_OPERATOR.md
@@ -30,7 +30,12 @@ The `operator.py` file needs to implement the following:
 2. `get_input_iter`: This method should return an iterator of input examples for
    the operator.
 3. `@register_benchmark`: This decorator should be used to register the benchmarks for
-   the operator.
+   the operator. It supports the following parameters:
+   - `baseline` (bool): Mark this backend as the baseline implementation
+   - `enabled` (bool): Whether this backend is enabled
+   - `fwd_only` (bool): Whether this backend only supports the forward pass
+   - `label` (str): Display label for this backend
+   - `tags` (List[str]): Tags for categorizing the backend (e.g., `['triton', 'custom']`)
 4. `get_bwd_fn`: This method should return a callable that performs the backward pass for
    the operator when needed.
 5. `get_grad_to_none`: This method should be overridden to set the gradients to your argument for
@@ -79,7 +84,12 @@ class Operator(BenchmarkOperator):
     def my_operator(self, input) -> Callable:
         return lambda: self.model(input)
 
-    @register_benchmark()
-    def my_operator2(self, input) -> Callable:
+    @register_benchmark(tags=['triton'])
+    def my_operator_triton(self, input) -> Callable:
+        return lambda: kernel_wrapper(input)
+
+    @register_benchmark(tags=['custom', 'experimental'])
+    def my_operator_custom(self, input) -> Callable:
+        # Your custom implementation
         return lambda: kernel_wrapper(input)
 ```
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 2c697515d..1e17a6223 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -84,6 +84,8 @@ class BenchmarkOperatorBackend:
     # need to be tested in ci
     # ci = False implies enabled = False
     ci: bool = True
+    # tags for categorizing the backend (e.g., ['triton', 'pt2'])
+    tags: Optional[List[str]] = None
 
 
 REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
@@ -91,6 +93,9 @@ class BenchmarkOperatorBackend:
 OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
 REGISTERED_X_VALS: Dict[str, str] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
+# Store tags defined in the @register_benchmark decorator
+# Format: {operator_name: {backend_name: [tag1, tag2, ...]}}
+REGISTERED_BACKEND_TAGS: Dict[str, Dict[str, List[str]]] = {}
 BASELINE_SKIP_METRICS = {
     "speedup",
     "accuracy",
@@ -598,6 +603,7 @@ def register_benchmark(
     enabled: bool = True,
     fwd_only: bool = False,
     label: Optional[str] = None,
+    tags: Optional[List[str]] = None,
 ):
     def decorator(function):
         op_name = (
@@ -612,6 +618,7 @@ def decorator(function):
             baseline=baseline,
             enabled=enabled,
             fwd_only=fwd_only,
+            tags=tags,
         )
         if op_name not in REGISTERED_BENCHMARKS:
             REGISTERED_BENCHMARKS[op_name] = OrderedDict()
@@ -619,6 +626,12 @@ def decorator(function):
         if backend_config.baseline:
             BASELINE_BENCHMARKS[op_name] = fn_name
 
+        # Store tags in the global REGISTERED_BACKEND_TAGS dict
+        if tags:
+            if op_name not in REGISTERED_BACKEND_TAGS:
+                REGISTERED_BACKEND_TAGS[op_name] = {}
+            REGISTERED_BACKEND_TAGS[op_name][fn_name] = tags
+
         def _inner(self, *args, **kwargs):
             return function(self, *args, **kwargs)
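To make the intended data flow concrete, below is a minimal standalone sketch of the merge semantics this patch wires up. It mirrors the logic of `merge_decorator_tags` rather than importing tritonbench, and the operator name `my_op` and backend names `triton_impl`/`eager_impl` are hypothetical:

```python
from typing import Dict, List, Optional

# Mirrors REGISTERED_BACKEND_TAGS: populated by @register_benchmark(tags=...),
# keyed by operator name, then by backend (function) name.
REGISTERED_BACKEND_TAGS: Dict[str, Dict[str, List[str]]] = {
    "my_op": {"triton_impl": ["triton", "custom"]},  # hypothetical entry
}


def merge_decorator_tags(op_name: str, backend_name: str, tags_dict: Optional[dict]) -> dict:
    """Union decorator tags with auto-detected tags, as in the patch."""
    if not tags_dict:
        tags_dict = {"tags": []}
    tags_dict.setdefault("tags", [])
    decorator_tags = REGISTERED_BACKEND_TAGS.get(op_name, {}).get(backend_name, [])
    if decorator_tags:
        # set() deduplicates but drops ordering, matching the patch's list(set(...))
        tags_dict["tags"] = list(set(decorator_tags + tags_dict["tags"]))
    return tags_dict


# Auto-detected tags from AST tracing are merged with decorator tags:
print(merge_decorator_tags("my_op", "triton_impl", {"tags": ["triton"]}))
# e.g. {'tags': ['custom', 'triton']} (duplicates removed, order not guaranteed)

# Backends without decorator tags keep their auto-detected tags unchanged:
print(merge_decorator_tags("my_op", "eager_impl", None))
# -> {'tags': []}
```

One design note: because the merge goes through `list(set(...))`, tag order in the output is nondeterministic; if stable output ever matters (e.g., for golden-file comparisons in the tagging benchmark), an order-preserving dedup such as `list(dict.fromkeys(decorator_tags + tags_dict["tags"]))` would be a drop-in alternative.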