@@ -49,8 +49,8 @@ class Experiment:
49
49
50
50
51
51
def get_configs () -> List [ExperimentConfig ]:
52
- input_shapes = [(2 ** 8 , 4096 ), ( 2 ** 12 , 4096 ), ( 2 ** 16 , 4096 )]
53
- n_groups_list = [4 , 8 , 16 ]
52
+ input_shapes = [(16640 , 5120 )] # (Mg, K)
53
+ n_groups_list = [16 , 128 ]
54
54
high_precision_dtypes = [torch .bfloat16 ]
55
55
configs = []
56
56
for input_shape , n_groups , high_precision_dtype in itertools .product (
@@ -129,6 +129,7 @@ def run_triton(
129
129
130
130
# bench torch
131
131
compiled_run_torch = torch .compile (run_torch )
132
+ warmup (compiled_run_torch , input_row_major , input_col_major , offs )
132
133
torch_time_us = benchmark_cuda_function_in_microseconds (
133
134
compiled_run_torch , input_row_major , input_col_major , offs
134
135
)
@@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]):
152
153
"high_precision_dtype" ,
153
154
"torch_time_us" ,
154
155
"triton_time_us" ,
156
+ "triton_speedup" ,
155
157
]
156
158
rows = []
157
159
for experiment in experiments :
@@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]):
165
167
experiment .config .high_precision_dtype ,
166
168
experiment .result .torch_time_us ,
167
169
experiment .result .triton_time_us ,
170
+ f"{ experiment .result .torch_time_us / experiment .result .triton_time_us :.2f} x" ,
168
171
]
169
172
)
170
173
print (tabulate (rows , headers = headers ))
0 commit comments