Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions benchmarks/tagging/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def setup_tritonbench_cwd():
from tritonbench.operators import list_operators
from tritonbench.utils.operator_utils import get_backends_for_operator
from tritonbench.utils.run_utils import load_operator_by_args
from tritonbench.utils.triton_op import REGISTERED_BACKEND_TAGS

try:
from ast_analyzer import build_backend_callees, trace_callees
Expand Down Expand Up @@ -112,7 +113,35 @@ def apply_name_based_heuristics(backend_name, tags_dict):
return tags_dict


def prevalidate_backends(backend_edges):
def merge_decorator_tags(op_name, backend_name, tags_dict):
"""
Merge tags from @register_benchmark decorator with auto-detected tags.

Args:
op_name: The operator name
backend_name: The backend name
tags_dict: Dictionary with auto-detected tags, e.g., {"tags": ["pt2"]}
If None, will be created.

Returns:
Updated tags_dict with decorator tags merged
"""
if not tags_dict:
tags_dict = {"tags": []}
if "tags" not in tags_dict:
tags_dict["tags"] = []

# Get decorator tags if they exist
decorator_tags = REGISTERED_BACKEND_TAGS.get(op_name, {}).get(backend_name, [])
if decorator_tags:
# Merge decorator tags with auto-detected tags (remove duplicates)
all_tags = list(set(decorator_tags + tags_dict["tags"]))
tags_dict["tags"] = all_tags

return tags_dict


def prevalidate_backends(backend_edges, op_name=None):
op_with_tags = {}
# heuristic: do not search torch.nn, torch.compile, and xformers backends
for backend, callees in backend_edges.items():
Expand Down Expand Up @@ -140,6 +169,11 @@ def prevalidate_backends(backend_edges):
op_with_tags[backend] = apply_name_based_heuristics(
backend, op_with_tags[backend]
)
# Merge with decorator tags if available
if op_name:
op_with_tags[backend] = merge_decorator_tags(
op_name, backend, op_with_tags[backend]
)

return op_with_tags

Expand All @@ -164,7 +198,7 @@ def trace_op(op):
backends=backends,
)
assert len(backend_edges) == len(backends)
op_with_tags[op] = prevalidate_backends(backend_edges)
op_with_tags[op] = prevalidate_backends(backend_edges, op_name=op)
remaining_backends = [
backend for backend in backends if backend not in op_with_tags[op]
]
Expand All @@ -181,6 +215,10 @@ def trace_op(op):
op_with_tags[op][backend] = apply_name_based_heuristics(
backend, op_with_tags[op][backend]
)
# Merge with decorator tags
op_with_tags[op][backend] = merge_decorator_tags(
op, backend, op_with_tags[op][backend]
)
return op_with_tags


Expand Down
16 changes: 13 additions & 3 deletions tritonbench/operators/ADD_OPERATOR.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ The `operator.py` file needs to implement the following:
2. `get_input_iter`: This method should return an iterator of input examples for the
operator.
3. `@register_benchmark`: This decorator should be used to register the benchmarks for
the operator.
the operator. It supports the following parameters:
- `baseline` (bool): Mark this as the baseline implementation
- `enabled` (bool): Whether this backend is enabled
- `fwd_only` (bool): Whether this backend only supports forward pass
- `label` (str): Display label for this backend
- `tags` (List[str]): Tags for categorizing the backend (e.g., `['triton', 'custom']`)
4. `get_bwd_fn`: This method should return a callable that performs the backward pass
for the operator when needed.
5. `get_grad_to_none`: This method should be overridden to set the gradients to your argument for
Expand Down Expand Up @@ -79,7 +84,12 @@ class Operator(BenchmarkOperator):
def my_operator(self, input) -> Callable:
return lambda: self.model(input)

@register_benchmark()
def my_operator2(self, input) -> Callable:
@register_benchmark(tags=['triton'])
def my_operator_triton(self, input) -> Callable:
return lambda: kernel_wrapper(input)

@register_benchmark(tags=['custom', 'experimental'])
def my_operator_custom(self, input) -> Callable:
# Your custom implementation
return lambda: kernel_wrapper(input)
```
13 changes: 13 additions & 0 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,18 @@ class BenchmarkOperatorBackend:
# need to be tested in ci
# ci = False implies enabled = False
ci: bool = True
# tags for categorizing the backend (e.g., ['triton', 'pt2'])
tags: Optional[List[str]] = None


REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
REGISTERED_X_VALS: Dict[str, str] = {}
BASELINE_BENCHMARKS: Dict[str, str] = {}
# Store tags defined in @register_benchmark decorator
# Format: {operator_name: {backend_name: [tag1, tag2, ...]}}
REGISTERED_BACKEND_TAGS: Dict[str, Dict[str, List[str]]] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove the use of REGISTERED_BACKEND_TAGS? Instead, we can add it as a field of class BenchmarkOperatorBackend

BASELINE_SKIP_METRICS = {
"speedup",
"accuracy",
Expand Down Expand Up @@ -598,6 +603,7 @@ def register_benchmark(
enabled: bool = True,
fwd_only: bool = False,
label: Optional[str] = None,
tags: Optional[List[str]] = None,
):
def decorator(function):
op_name = (
Expand All @@ -612,13 +618,20 @@ def decorator(function):
baseline=baseline,
enabled=enabled,
fwd_only=fwd_only,
tags=tags,
)
if op_name not in REGISTERED_BENCHMARKS:
REGISTERED_BENCHMARKS[op_name] = OrderedDict()
REGISTERED_BENCHMARKS[op_name][fn_name] = backend_config
if backend_config.baseline:
BASELINE_BENCHMARKS[op_name] = fn_name

# Store tags in the global REGISTERED_BACKEND_TAGS dict
if tags:
if op_name not in REGISTERED_BACKEND_TAGS:
REGISTERED_BACKEND_TAGS[op_name] = {}
REGISTERED_BACKEND_TAGS[op_name][fn_name] = tags

def _inner(self, *args, **kwargs):
return function(self, *args, **kwargs)

Expand Down
Loading