Skip to content

Commit 496c120

Browse files
authored
further additions of dynamic benchmarking
Differential Revision: D83537801 Pull Request resolved: #496
1 parent ee75188 commit 496c120

File tree

2 files changed

+92
-3
lines changed

2 files changed

+92
-3
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Callable, Dict, Generator, List, Optional
2+
3+
from tritonbench.utils.triton_op import BenchmarkOperator
4+
5+
6+
def dynamic_run(
7+
benchmarks: Dict[str, Callable],
8+
input_iter: Optional[Generator],
9+
**kwargs,
10+
) -> None:
11+
"""
12+
Run a list of benchmarks with a given set of inputs and kwargs.
13+
Kwargs in this case are the command-line arguments available in tritonbench.
14+
15+
Example:
16+
17+
def triton_add(x, y):
18+
...
19+
20+
def input_iter():
21+
size = 2**12
22+
x = torch.rand(size, device="cuda")
23+
y = torch.rand(size, device="cuda")
24+
yield x,y
25+
26+
benchmarks = {
27+
"triton_add": triton_add,
28+
"triton_add2": triton_add,
29+
}
30+
31+
dynamic_run(benchmarks=benchmarks, input_iter=input_iter, benchmark_name="vector_add")
32+
"""
33+
34+
# Convert kwargs into a list of command-line arguments
35+
arg_list = []
36+
for k, v in kwargs.items():
37+
key = f"--{k.replace('_', '-')}"
38+
arg_list.append(key)
39+
arg_list.append(str(v))
40+
41+
op = BenchmarkOperator(extra_args=arg_list)
42+
43+
op.set_input_iter(input_iter)
44+
45+
for k, v in benchmarks.items():
46+
op.add_benchmark(bm_func_name=k, bm_callable=v)
47+
48+
op.run()
49+
print(op.output)
50+
return op.output
51+
52+
53+
def dynamic_run_once(
54+
benchmarks: Dict[str, Callable], single_input: Optional[List], **kwargs
55+
):
56+
"""
57+
Run a list of benchmarks with a given set of inputs and kwargs.
58+
Kwargs in this case are the command-line arguments available in tritonbench.
59+
60+
The single_input is a list of arguments that will be passed to the benchmark function all together
61+
62+
Example:
63+
64+
def triton_add(x, y):
65+
...
66+
67+
benchmarks = {
68+
"triton_add": triton_add,
69+
"triton_add2": triton_add,
70+
}
71+
size = 2**12
72+
x = torch.rand(size, device="cuda")
73+
y = torch.rand(size, device="cuda")
74+
dynamic_run_once(benchmarks=benchmarks, single_input=[x, y], benchmark_name="vector_add")
75+
"""
76+
77+
def input_iterator(*args):
78+
def generator():
79+
yield args
80+
81+
return generator
82+
83+
input_generator = input_iterator(*single_input)
84+
output = dynamic_run(benchmarks, input_generator, **kwargs)
85+
return output

tritonbench/utils/triton_op.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -939,18 +939,22 @@ def input_callable(self):
939939
)
940940
self.input_iter = input_iter
941941
self._available_num_inputs = sum(1 for _ in self.get_input_iter())
942-
self._num_inputs = self._available_num_inputs - self._input_id
942+
self._num_inputs = self._available_num_inputs - len(self._input_ids)
943+
self._input_ids = [i for i in range(0, self._num_inputs)]
943944

944945
def add_benchmark(self, bm_func_name: str, bm_callable: Callable):
946+
def _inner(self, *args, **kwargs):
947+
return bm_callable(*args, **kwargs)
948+
945949
decorator_kwargs = {
946950
"operator_name": self.name,
947951
"func_name": bm_func_name,
948952
"enabled": True,
949953
}
950-
decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
954+
decorated_func = register_benchmark(**decorator_kwargs)(_inner)
951955
bound_method = types.MethodType(decorated_func, self)
952956
setattr(self, bm_func_name or bm_callable.__name__, bound_method)
953-
REGISTERED_BENCHMARKS[bm_func_name] = bm_callable
957+
REGISTERED_BENCHMARKS[bm_func_name] = _inner
954958

955959
def run(
956960
self,

0 commit comments

Comments
 (0)