Skip to content

Commit 0bb2da4

Browse files
authored
Migrate the grouped_gemm inputs from qwen3
Differential Revision: D84006584 Pull Request resolved: #522
1 parent 48b9910 commit 0bb2da4

File tree

5 files changed

+329
-9
lines changed

5 files changed

+329
-9
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
{
2+
"grouped_gemm": [
3+
{
4+
"count": 1,
5+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(5120, 768)']})"
6+
},
7+
{
8+
"count": 1,
9+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(5120, 2048)']})"
10+
},
11+
{
12+
"count": 1,
13+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(7168, 2048)', '(8192, 2048)']})"
14+
},
15+
{
16+
"count": 1,
17+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(16, 768)', '(16, 768)']})"
18+
},
19+
{
20+
"count": 1,
21+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(6144, 2048)', '(4096, 2048)']})"
22+
},
23+
{
24+
"count": 1,
25+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(5120, 2048)', '(8192, 2048)', '(7168, 2048)']})"
26+
},
27+
{
28+
"count": 1,
29+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(2048, 2048)', '(7168, 2048)', '(1024, 2048)']})"
30+
},
31+
{
32+
"count": 1,
33+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(8192, 2048)', '(5120, 2048)']})"
34+
},
35+
{
36+
"count": 1,
37+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(7168, 2048)', '(7168, 2048)']})"
38+
},
39+
{
40+
"count": 1,
41+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(4096, 2048)', '(8192, 2048)', '(1024, 2048)', '(1024, 2048)']})"
42+
},
43+
{
44+
"count": 1,
45+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(8192, 768)', '(5120, 768)']})"
46+
},
47+
{
48+
"count": 1,
49+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(3072, 2048)', '(8192, 2048)', '(1024, 2048)', '(1024, 2048)', '(1024, 2048)']})"
50+
},
51+
{
52+
"count": 1,
53+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(8192, 2048)', '(1024, 2048)', '(8192, 2048)', '(1024, 2048)']})"
54+
},
55+
{
56+
"count": 1,
57+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(5120, 768)', '(7168, 768)', '(7168, 768)', '(2048, 768)']})"
58+
},
59+
{
60+
"count": 1,
61+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(5120, 768)', '(8192, 768)', '(7168, 768)']})"
62+
},
63+
{
64+
"count": 1,
65+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(4096, 2048)', '(1024, 2048)', '(2048, 2048)', '(8192, 2048)', '(1024, 2048)']})"
66+
},
67+
{
68+
"count": 1,
69+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(3072, 768)', '(1024, 768)', '(7168, 768)', '(8192, 768)']})"
70+
},
71+
{
72+
"count": 1,
73+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(4096, 768)', '(1024, 768)', '(2048, 768)', '(8192, 768)', '(1024, 768)']})"
74+
},
75+
{
76+
"count": 1,
77+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(7168, 2048)', '(1024, 2048)', '(5120, 2048)']})"
78+
},
79+
{
80+
"count": 1,
81+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(7168, 768)', '(4096, 768)', '(5120, 768)']})"
82+
},
83+
{
84+
"count": 1,
85+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(3072, 2048)', '(1024, 2048)', '(2048, 2048)', '(4096, 2048)', '(1024, 2048)']})"
86+
},
87+
{
88+
"count": 1,
89+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(5120, 768)', '(1024, 768)', '(7168, 768)', '(3072, 768)', '(7168, 768)', '(1024, 768)']})"
90+
},
91+
{
92+
"count": 1,
93+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(6144, 768)', '(1024, 768)', '(1024, 768)', '(2048, 768)', '(1024, 768)', '(4096, 768)', '(1024, 768)', '(1024, 768)']})"
94+
},
95+
{
96+
"count": 1,
97+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(7168, 2048)', '(1024, 2048)', '(4096, 2048)']})"
98+
},
99+
{
100+
"count": 1,
101+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(2048, 768)', '(7168, 768)', '(1024, 768)', '(4096, 768)']})"
102+
},
103+
{
104+
"count": 1,
105+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(7168, 768)', '(5120, 768)', '(8192, 768)']})"
106+
},
107+
{
108+
"count": 1,
109+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(5120, 2048)', '(1024, 2048)', '(7168, 2048)', '(3072, 2048)', '(7168, 2048)', '(1024, 2048)']})"
110+
},
111+
{
112+
"count": 1,
113+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(5120, 2048)', '(3072, 2048)', '(6144, 2048)', '(1024, 2048)']})"
114+
},
115+
{
116+
"count": 1,
117+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(2048, 768)', '(3072, 768)', '(8192, 768)', '(1024, 768)', '(1024, 768)', '(1024, 768)']})"
118+
},
119+
{
120+
"count": 1,
121+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(3072, 768)', '(2048, 768)', '(6144, 768)']})"
122+
},
123+
{
124+
"count": 1,
125+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(8192, 2048)', '(1024, 2048)', '(7168, 2048)', '(6144, 2048)', '(7168, 2048)']})"
126+
},
127+
{
128+
"count": 1,
129+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(2048, 768)', '(1024, 768)', '(3072, 768)', '(2048, 768)', '(4096, 768)']})"
130+
},
131+
{
132+
"count": 1,
133+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(3072, 2048)', '(8192, 2048)', '(1024, 2048)', '(1024, 2048)', '(1024, 2048)']})"
134+
},
135+
{
136+
"count": 1,
137+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(7168, 2048)', '(4096, 2048)', '(5120, 2048)']})"
138+
},
139+
{
140+
"count": 1,
141+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(3072, 2048)', '(2048, 2048)', '(1024, 2048)', '(4096, 2048)', '(3072, 2048)', '(7168, 2048)']})"
142+
},
143+
{
144+
"count": 1,
145+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(3072, 2048)', '(1024, 2048)', '(4096, 2048)', '(3072, 2048)', '(1024, 2048)', '(1024, 2048)']})"
146+
},
147+
{
148+
"count": 1,
149+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(5120, 2048)', '(4096, 2048)']})"
150+
},
151+
{
152+
"count": 1,
153+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(6144, 2048)', '(1024, 2048)', '(1024, 2048)', '(1024, 2048)', '(3072, 2048)']})"
154+
},
155+
{
156+
"count": 1,
157+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(2048, 2048)', '(5120, 2048)']})"
158+
},
159+
{
160+
"count": 1,
161+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(1024, 768)', '(4096, 768)']})"
162+
},
163+
{
164+
"count": 1,
165+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(3072, 768)', '(2048, 768)', '(6144, 768)']})"
166+
},
167+
{
168+
"count": 1,
169+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(6144, 2048)', '(1024, 2048)', '(2048, 2048)', '(2048, 2048)']})"
170+
},
171+
{
172+
"count": 1,
173+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(2048, 768)', '(5120, 768)']})"
174+
},
175+
{
176+
"count": 1,
177+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(2048, 768)', '(1024, 768)', '(3072, 768)']})"
178+
},
179+
{
180+
"count": 1,
181+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(4096, 2048)', '(2048, 2048)', '(1024, 2048)']})"
182+
},
183+
{
184+
"count": 1,
185+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(2048, 2048)', '(1024, 2048)', '(3072, 2048)']})"
186+
},
187+
{
188+
"count": 1,
189+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(5120, 2048)', '(1024, 2048)', '(1024, 2048)']})"
190+
},
191+
{
192+
"count": 1,
193+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(3072, 2048)', '(2048, 2048)', '(5120, 2048)']})"
194+
},
195+
{
196+
"count": 1,
197+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(1024, 2048)', '(1024, 2048)', '(4096, 2048)']})"
198+
},
199+
{
200+
"count": 1,
201+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(6144, 768)', '(1024, 768)', '(2048, 768)', '(2048, 768)']})"
202+
},
203+
{
204+
"count": 1,
205+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(5120, 768)', '(1024, 768)', '(1024, 768)']})"
206+
},
207+
{
208+
"count": 1,
209+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(4096, 768)', '(2048, 768)', '(1024, 768)']})"
210+
},
211+
{
212+
"count": 1,
213+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(3072, 768)', '(2048, 768)', '(5120, 768)']})"
214+
},
215+
{
216+
"count": 1,
217+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(2048, 768)', '(5120, 768)', '(4096, 768)']})"
218+
},
219+
{
220+
"count": 1,
221+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(3072, 2048)', '(2048, 2048)', '(6144, 2048)']})"
222+
},
223+
{
224+
"count": 1,
225+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(3072, 768)', '(1024, 768)', '(4096, 768)', '(3072, 768)', '(1024, 768)', '(1024, 768)']})"
226+
},
227+
{
228+
"count": 1,
229+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(768, 768)', '(640, 768)', '(64, 768)', '(1024, 768)', '(320, 768)', '(256, 768)', '(256, 768)', '(128, 768)', '(960, 768)', '(704, 768)', '(640, 768)', '(640, 768)', '(128, 768)', '(576, 768)', '(256, 768)', '(512, 768)']})"
230+
},
231+
{
232+
"count": 1,
233+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(256, 2048)', '(512, 2048)', '(64, 2048)', '(832, 2048)', '(576, 2048)', '(512, 2048)', '(704, 2048)', '(448, 2048)', '(448, 2048)', '(2560, 2048)', '(1152, 2048)', '(640, 2048)', '(384, 2048)', '(64, 2048)', '(384, 2048)', '(1024, 2048)']})"
234+
},
235+
{
236+
"count": 1,
237+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(320, 2048)', '(384, 2048)', '(320, 2048)', '(128, 2048)', '(64, 2048)', '(832, 2048)', '(128, 2048)', '(832, 2048)', '(320, 2048)', '(448, 2048)', '(384, 2048)', '(1024, 2048)', '(704, 2048)', '(960, 2048)', '(768, 2048)', '(320, 2048)']})"
238+
},
239+
{
240+
"count": 1,
241+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(1024, 768)', '(384, 768)', '(448, 768)', '(768, 768)', '(640, 768)', '(448, 768)', '(640, 768)', '(1152, 768)', '(704, 768)', '(384, 768)', '(64, 768)', '(384, 768)', '(512, 768)', '(128, 768)', '(704, 768)', '(1600, 768)']})"
242+
},
243+
{
244+
"count": 1,
245+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(128, 2048)', '(128, 2048)', '(256, 2048)', '(448, 2048)', '(704, 2048)', '(1152, 2048)', '(192, 2048)', '(512, 2048)', '(960, 2048)', '(384, 2048)', '(256, 2048)', '(384, 2048)', '(1600, 2048)', '(128, 2048)', '(512, 2048)', '(1600, 2048)']})"
246+
},
247+
{
248+
"count": 1,
249+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(64, 768)', '(64, 768)', '(64, 768)', '(128, 768)', '(192, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)']})"
250+
},
251+
{
252+
"count": 1,
253+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(64, 2048)', '(64, 2048)', '(64, 2048)', '(128, 2048)', '(192, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)']})"
254+
},
255+
{
256+
"count": 1,
257+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(256, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(128, 2048)', '(64, 2048)', '(64, 2048)']})"
258+
},
259+
{
260+
"count": 1,
261+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(256, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(320, 768)', '(64, 768)']})"
262+
},
263+
{
264+
"count": 1,
265+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(64, 2048)', '(64, 2048)', '(64, 2048)', '(320, 2048)', '(64, 2048)', '(64, 2048)']})"
266+
},
267+
{
268+
"count": 1,
269+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(256, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(64, 768)', '(128, 768)', '(64, 768)', '(64, 768)']})"
270+
},
271+
{
272+
"count": 1,
273+
"inputs": "((), {'B': '(2048, 2048)', 'A_list': ['(256, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(64, 2048)', '(320, 2048)', '(64, 2048)']})"
274+
},
275+
{
276+
"count": 1,
277+
"inputs": "((), {'B': '(768, 768)', 'A_list': ['(64, 768)', '(64, 768)', '(64, 768)', '(320, 768)', '(64, 768)', '(64, 768)']})"
278+
}
279+
]
280+
}

tritonbench/data/loader.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import importlib
2+
import os
23
from pathlib import Path
34
from typing import Any
45

56
from tritonbench.utils.env_utils import is_fbcode
67

7-
SUPPORTED_INPUT_OPS = ["highway_self_gating"]
8+
SUPPORTED_INPUT_OPS = ["highway_self_gating", "grouped_gemm"]
89

910
INPUT_CONFIG_DIR = Path(__file__).parent.joinpath("input_configs")
1011
INTERNAL_INPUT_CONFIG_DIR = (
@@ -19,7 +20,15 @@ def get_input_loader(tritonbench_op: Any, op: str, input: str):
1920
hasattr(tritonbench_op, "aten_op_name") or op in SUPPORTED_INPUT_OPS
2021
), f"Unsupported op: {op}. "
2122
op_module = ".".join(tritonbench_op.__module__.split(".")[:-1])
22-
generator_module = importlib.import_module(f"{op_module}.input_loader")
23-
input_iter_getter = generator_module.get_input_iter
24-
input_iter = input_iter_getter(tritonbench_op, op, input)
25-
return input_iter
23+
generator_module = importlib.import_module(op_module)
24+
input_loader_cls = generator_module.InputLoader
25+
if os.path.exists(input):
26+
input_file_path = Path(input)
27+
elif INPUT_CONFIG_DIR.joinpath(input).exists():
28+
input_file_path = INPUT_CONFIG_DIR.joinpath(input)
29+
elif INTERNAL_INPUT_CONFIG_DIR.joinpath(input).exists():
30+
input_file_path = INTERNAL_INPUT_CONFIG_DIR.joinpath(input)
31+
else:
32+
raise RuntimeError(f"Input file {input} does not exist.")
33+
input_loader = input_loader_cls(tritonbench_op, op, input_file_path)
34+
return input_loader.get_input_iter()

tritonbench/operator_loader/aten/input_loader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import math
99
from collections import Counter, defaultdict
10-
from pathlib import Path
1110
from typing import Any, Callable, Generator
1211

1312
import torch
@@ -164,7 +163,7 @@ def deserialize_args(inps):
164163
return eval(inps.strip().strip("'").strip('"'), global_vals)
165164

166165

167-
class OperatorInputsLoader:
166+
class OperatorInputLoader:
168167
def __init__(self, op_name: str, json_file_path: str):
169168
self.op_name = op_name
170169
self.operator_db = defaultdict(Counter)
@@ -229,5 +228,5 @@ def merge(self, other):
229228
def get_input_iter(tritonbench_op: Any, op: str, input: str) -> Generator:
230229
aten_op_name = tritonbench_op.aten_op_name
231230
input_file_path = INPUT_CONFIG_DIR.joinpath(input)
232-
operator_inputs_loader = OperatorInputsLoader(aten_op_name, input_file_path)
231+
operator_inputs_loader = OperatorInputLoader(aten_op_name, input_file_path)
233232
return operator_inputs_loader.get_input_iter()
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .operator import Operator
1+
from .input_loader import InputLoader # noqa
2+
from .operator import Operator # noqa
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""
2+
Input loader for Grouped GEMM operator.
3+
"""
4+
5+
from typing import Callable
6+
7+
from tritonbench.operator_loader.aten.input_loader import OperatorInputLoader
8+
9+
10+
class InputLoader(OperatorInputLoader):
11+
def __init__(self, tritonbench_op: str, op_name: str, json_file_path: str):
12+
super().__init__(op_name, json_file_path)
13+
self.op = tritonbench_op
14+
15+
def get_input_iter(
16+
self,
17+
) -> Callable:
18+
shapes = [eval(inp)[1] for inp, _cnt in self.operator_db[self.op_name].items()]
19+
parsed = []
20+
for entry in shapes:
21+
B_shape = (
22+
eval(entry["B"]) if isinstance(entry["B"], str) else tuple(entry["B"])
23+
)
24+
A_shapes = [
25+
eval(a) if isinstance(a, str) else tuple(a) for a in entry["A_list"]
26+
]
27+
parsed.append((A_shapes, B_shape))
28+
29+
# Set shapes on the operator
30+
self.op.external_shapes = parsed
31+
return self.op.get_input_iter

0 commit comments

Comments
 (0)