From 25410108283a77c84b6cf57f0ee29465f15223dd Mon Sep 17 00:00:00 2001 From: jialiangqu Date: Sun, 23 Nov 2025 23:37:44 +0000 Subject: [PATCH] Add tags parameter support to @register_benchmark decorator Implements the feature requested in PR #647 comments - allows users to specify tags directly in the decorator like @register_benchmark(tags=['triton', 'custom']). Tags from the decorator are merged with auto-detected tags (from AST analysis and name-based heuristics). Added REGISTERED_BACKEND_TAGS global dict to store decorator tags, and updated the tagging system in run.py to merge all tag sources. Updated documentation with examples showing how to use the new tags parameter. --- benchmarks/tagging/run.py | 42 +++++++++++++++++++++++++-- tritonbench/operators/ADD_OPERATOR.md | 16 ++++++++-- tritonbench/utils/triton_op.py | 13 +++++++++ 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/benchmarks/tagging/run.py b/benchmarks/tagging/run.py index a782093bd..5e330aec5 100644 --- a/benchmarks/tagging/run.py +++ b/benchmarks/tagging/run.py @@ -42,6 +42,7 @@ def setup_tritonbench_cwd(): from tritonbench.operators import list_operators from tritonbench.utils.operator_utils import get_backends_for_operator from tritonbench.utils.run_utils import load_operator_by_args +from tritonbench.utils.triton_op import REGISTERED_BACKEND_TAGS try: from ast_analyzer import build_backend_callees, trace_callees @@ -112,7 +113,35 @@ def apply_name_based_heuristics(backend_name, tags_dict): return tags_dict -def prevalidate_backends(backend_edges): +def merge_decorator_tags(op_name, backend_name, tags_dict): + """ + Merge tags from @register_benchmark decorator with auto-detected tags. + + Args: + op_name: The operator name + backend_name: The backend name + tags_dict: Dictionary with auto-detected tags, e.g., {"tags": ["pt2"]} + If None, will be created. + + Returns: + Updated tags_dict with decorator tags merged + """ + if not tags_dict: + tags_dict = {"tags": []} + if "tags" not in tags_dict: + tags_dict["tags"] = [] + + # Get decorator tags if they exist + decorator_tags = REGISTERED_BACKEND_TAGS.get(op_name, {}).get(backend_name, []) + if decorator_tags: + # Merge decorator tags with auto-detected tags (remove duplicates) + all_tags = list(set(decorator_tags + tags_dict["tags"])) + tags_dict["tags"] = all_tags + + return tags_dict + + +def prevalidate_backends(backend_edges, op_name=None): op_with_tags = {} # heuristic: do not search torch.nn, torch.compile, and xformers backends for backend, callees in backend_edges.items(): @@ -140,6 +169,11 @@ def prevalidate_backends(backend_edges): op_with_tags[backend] = apply_name_based_heuristics( backend, op_with_tags[backend] ) + # Merge with decorator tags if available + if op_name: + op_with_tags[backend] = merge_decorator_tags( + op_name, backend, op_with_tags[backend] + ) return op_with_tags @@ -164,7 +198,7 @@ def trace_op(op): backends=backends, ) assert len(backend_edges) == len(backends) - op_with_tags[op] = prevalidate_backends(backend_edges) + op_with_tags[op] = prevalidate_backends(backend_edges, op_name=op) remaining_backends = [ backend for backend in backends if backend not in op_with_tags[op] ] @@ -181,6 +215,10 @@ def trace_op(op): op_with_tags[op][backend] = apply_name_based_heuristics( backend, op_with_tags[op][backend] ) + # Merge with decorator tags + op_with_tags[op][backend] = merge_decorator_tags( + op, backend, op_with_tags[op][backend] + ) return op_with_tags diff --git a/tritonbench/operators/ADD_OPERATOR.md b/tritonbench/operators/ADD_OPERATOR.md index 1aaff99db..bd03396fd 100644 --- a/tritonbench/operators/ADD_OPERATOR.md +++ b/tritonbench/operators/ADD_OPERATOR.md @@ -30,7 +30,12 @@ The `operator.py` file needs to implement the following: 2. `get_input_iter`: This method should return an iterator of input examples for the operator. 3. `@register_benchmark`: This decorator should be used to register the benchmarks for - the operator. + the operator. It supports the following parameters: + - `baseline` (bool): Mark this as the baseline implementation + - `enabled` (bool): Whether this backend is enabled + - `fwd_only` (bool): Whether this backend only supports forward pass + - `label` (str): Display label for this backend + - `tags` (List[str]): Tags for categorizing the backend (e.g., `['triton', 'custom']`) 4. `get_bwd_fn`: This method should return a callable that performs the backward pass for the operator when needed. 5. `get_grad_to_none`: This method should be overridden to set the gradients to your argument for @@ -79,7 +84,12 @@ class Operator(BenchmarkOperator): def my_operator(self, input) -> Callable: return lambda: self.model(input) - @register_benchmark() - def my_operator2(self, input) -> Callable: + @register_benchmark(tags=['triton']) + def my_operator_triton(self, input) -> Callable: + return lambda: kernel_wrapper(input) + + @register_benchmark(tags=['custom', 'experimental']) + def my_operator_custom(self, input) -> Callable: + # Your custom implementation return lambda: kernel_wrapper(input) ``` diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 2c697515d..1e17a6223 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -84,6 +84,8 @@ class BenchmarkOperatorBackend: # need to be tested in ci # ci = False implies enabled = False ci: bool = True + # tags for categorizing the backend (e.g., ['triton', 'pt2']) + tags: Optional[List[str]] = None REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {} @@ -91,6 +93,9 @@ class BenchmarkOperatorBackend: OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list) REGISTERED_X_VALS: Dict[str, str] = {} BASELINE_BENCHMARKS: Dict[str, str] = {} +# Store tags defined in @register_benchmark decorator +# Format: {operator_name: {backend_name: [tag1, tag2, ...]}} +REGISTERED_BACKEND_TAGS: Dict[str, Dict[str, List[str]]] = {} BASELINE_SKIP_METRICS = { "speedup", "accuracy", @@ -598,6 +603,7 @@ def register_benchmark( enabled: bool = True, fwd_only: bool = False, label: Optional[str] = None, + tags: Optional[List[str]] = None, ): def decorator(function): op_name = ( @@ -612,6 +618,7 @@ def decorator(function): baseline=baseline, enabled=enabled, fwd_only=fwd_only, + tags=tags, ) if op_name not in REGISTERED_BENCHMARKS: REGISTERED_BENCHMARKS[op_name] = OrderedDict() @@ -619,6 +626,12 @@ def decorator(function): if backend_config.baseline: BASELINE_BENCHMARKS[op_name] = fn_name + # Store tags in the global REGISTERED_BACKEND_TAGS dict + if tags: + if op_name not in REGISTERED_BACKEND_TAGS: + REGISTERED_BACKEND_TAGS[op_name] = {} + REGISTERED_BACKEND_TAGS[op_name][fn_name] = tags + def _inner(self, *args, **kwargs): return function(self, *args, **kwargs)