2121from triton_kernels_benchmark import xetla_kernel
2222from triton_kernels_benchmark import cutlass_kernel
2323
24+ from utils .dpas_layout_analyzer import calculate_optimal_warps_per_cta , calculate_optimal_rep_clusters
25+
2426
2527def get_matmul_autotune_configs () -> List [triton .Config ]:
2628 configs = [
@@ -178,32 +180,20 @@ def get_gluon_matmul_autotune_configs(base_configs_fn: Callable) -> List[triton.
178180 # Append additional meta parameters needed for gluon kernel
179181 # To determine prefetch distance and DPAS layout
180182 {** config .kwargs , 'NUM_STAGES' : config .num_stages , 'NUM_WARPS' : config .num_warps },
181- num_stages = config .num_stages ,
182- num_warps = config .num_warps
183- )
184- for config in base_configs
183+ num_stages = config .num_stages , num_warps = config .num_warps ) for config in base_configs
185184 ]
186185
187186
@gluon.constexpr_function
def get_dpas_layout(num_warps: ttgl.constexpr, m_shape: ttgl.constexpr, n_shape: ttgl.constexpr,
                    k_shape: ttgl.constexpr) -> ttgl.constexpr:
    """Construct the Intel DPAS layout for a given warp count and tile shape.

    The warp distribution across the CTA and the repetition clusters are
    derived from the (M, N, K) block shape by the DPAS layout analyzer
    helpers — presumably to match what the Triton compiler passes would
    choose for the same configuration (TODO confirm against the passes).
    """
    # Fixed sub-group (warp) size on Intel Xe GPUs.
    threads_per_warp = 16

    # Derive the warp grid and repetition clusters from the tile shape.
    warps_per_cta = calculate_optimal_warps_per_cta(num_warps, m_shape, n_shape)
    rep_cluster = calculate_optimal_rep_clusters(m_shape, n_shape, k_shape, threads_per_warp, warps_per_cta)

    return IntelDPASLayout(
        repeatCount=8,
        systolic_depth=8,
        execution_size=16,
        ops_per_chan=2,
        warps_per_cta=warps_per_cta,
        rep_cluster=rep_cluster,
        threads_per_warp=threads_per_warp,
    )
207197
208198
209199@triton .autotune (
@@ -217,16 +207,14 @@ def gluon_matmul_kernel_dpas_tensor_desc(
217207 # Matrix dimensions
218208 M : ttgl .constexpr , N : ttgl .constexpr , K : ttgl .constexpr ,
219209 # Stride variables
220- stride_am : ttgl .constexpr , stride_ak : ttgl .constexpr ,
221- stride_bk : ttgl .constexpr , stride_bn : ttgl .constexpr ,
210+ stride_am : ttgl .constexpr , stride_ak : ttgl .constexpr , stride_bk : ttgl .constexpr , stride_bn : ttgl .constexpr ,
222211 stride_cm : ttgl .constexpr , stride_cn : ttgl .constexpr ,
223212 # Meta parameters
224213 BLOCK_SIZE_M : ttgl .constexpr , BLOCK_SIZE_N : ttgl .constexpr , BLOCK_SIZE_K : ttgl .constexpr ,
225214 GROUP_SIZE_M : ttgl .constexpr ,
226215 # Gluon meta parameters
227216 NUM_STAGES : ttgl .constexpr , NUM_WARPS : ttgl .constexpr ):
228- layout : ttgl .constexpr = get_dpas_layout (NUM_WARPS )
229-
217+ layout : ttgl .constexpr = get_dpas_layout (NUM_WARPS , BLOCK_SIZE_M , BLOCK_SIZE_N , BLOCK_SIZE_K )
230218
231219 lhs_layout : ttgl .constexpr = ttgl .DotOperandLayout (parent = layout , operand_index = 0 , k_width = 1 )
232220 rhs_layout : ttgl .constexpr = ttgl .DotOperandLayout (parent = layout , operand_index = 1 , k_width = 2 )
@@ -241,19 +229,19 @@ def gluon_matmul_kernel_dpas_tensor_desc(
241229 pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
242230 pid_n = (pid % num_pid_in_group ) // group_size_m
243231
244- a_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (a_ptr , (M , K ), (stride_am , stride_ak ), (BLOCK_SIZE_M , BLOCK_SIZE_K ),
245- lhs_layout )
246- b_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (b_ptr , (K , N ), (stride_bk , stride_bn ), (BLOCK_SIZE_K , BLOCK_SIZE_N ),
247- rhs_layout )
248- c_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (c_ptr , (M , N ), (stride_cm , stride_cn ), (BLOCK_SIZE_M , BLOCK_SIZE_N ), layout )
232+ a_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (a_ptr , (M , K ), (stride_am , stride_ak ),
233+ (BLOCK_SIZE_M , BLOCK_SIZE_K ), lhs_layout )
234+ b_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (b_ptr , (K , N ), (stride_bk , stride_bn ),
235+ (BLOCK_SIZE_K , BLOCK_SIZE_N ), rhs_layout )
236+ c_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (c_ptr , (M , N ), (stride_cm , stride_cn ),
237+ (BLOCK_SIZE_M , BLOCK_SIZE_N ), layout )
249238
250239 # Clear accumulator
251240 zero_tensor = ttgl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = ttgl .float32 , layout = layout )
252241 c_desc .store_2d ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ], zero_tensor )
253242
254243 accumulator = c_desc .load_2d ([pid_m * BLOCK_SIZE_M , pid_n * BLOCK_SIZE_N ])
255244
256-
257245 # Prefetch first blocks for A and B matrices (pre-loop prefetches)
258246 for i in range (NUM_STAGES ):
259247 if i * BLOCK_SIZE_K < K :
@@ -286,15 +274,15 @@ def gluon_matmul_kernel_dpas_tensor_desc_batched(
286274 # Matrix dimensions
287275 B : ttgl .constexpr , M : ttgl .constexpr , N : ttgl .constexpr , K : ttgl .constexpr ,
288276 # Stride variables
289- stride_az : ttgl .constexpr , stride_am : ttgl .constexpr , stride_ak : ttgl .constexpr ,
290- stride_bz : ttgl .constexpr , stride_bk : ttgl .constexpr , stride_bn : ttgl .constexpr ,
291- stride_cz : ttgl . constexpr , stride_cm : ttgl . constexpr , stride_cn : ttgl .constexpr ,
277+ stride_az : ttgl .constexpr , stride_am : ttgl .constexpr , stride_ak : ttgl .constexpr , stride_bz : ttgl . constexpr ,
278+ stride_bk : ttgl .constexpr , stride_bn : ttgl .constexpr , stride_cz : ttgl . constexpr , stride_cm : ttgl .constexpr ,
279+ stride_cn : ttgl .constexpr ,
292280 # Meta parameters
293281 BLOCK_SIZE_M : ttgl .constexpr , BLOCK_SIZE_N : ttgl .constexpr , BLOCK_SIZE_K : ttgl .constexpr ,
294282 GROUP_SIZE_M : ttgl .constexpr ,
295283 # Gluon meta parameters
296284 NUM_STAGES : ttgl .constexpr , NUM_WARPS : ttgl .constexpr ):
297- layout : ttgl .constexpr = get_dpas_layout (NUM_WARPS )
285+ layout : ttgl .constexpr = get_dpas_layout (NUM_WARPS , BLOCK_SIZE_M , BLOCK_SIZE_N , BLOCK_SIZE_K )
298286
299287 lhs_layout : ttgl .constexpr = ttgl .DotOperandLayout (parent = layout , operand_index = 0 , k_width = 1 )
300288 rhs_layout : ttgl .constexpr = ttgl .DotOperandLayout (parent = layout , operand_index = 1 , k_width = 2 )
@@ -315,18 +303,12 @@ def gluon_matmul_kernel_dpas_tensor_desc_batched(
315303 offset_b = bid .to (ttgl .int64 ) * stride_bz
316304 offset_c = bid .to (ttgl .int64 ) * stride_cz
317305
318- a_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (
319- a_ptr + offset_a , (M , K ), (stride_am , stride_ak ),
320- (BLOCK_SIZE_M , BLOCK_SIZE_K ), lhs_layout
321- )
322- b_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (
323- b_ptr + offset_b , (K , N ), (stride_bk , stride_bn ),
324- (BLOCK_SIZE_K , BLOCK_SIZE_N ), rhs_layout
325- )
326- c_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (
327- c_ptr + offset_c , (M , N ), (stride_cm , stride_cn ),
328- (BLOCK_SIZE_M , BLOCK_SIZE_N ), layout
329- )
306+ a_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (a_ptr + offset_a , (M , K ), (stride_am , stride_ak ),
307+ (BLOCK_SIZE_M , BLOCK_SIZE_K ), lhs_layout )
308+ b_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (b_ptr + offset_b , (K , N ), (stride_bk , stride_bn ),
309+ (BLOCK_SIZE_K , BLOCK_SIZE_N ), rhs_layout )
310+ c_desc = ttgl .intel .xpu .xe .make_tensor_descriptor (c_ptr + offset_c , (M , N ), (stride_cm , stride_cn ),
311+ (BLOCK_SIZE_M , BLOCK_SIZE_N ), layout )
330312
331313 # Clear accumulator
332314 zero_tensor = ttgl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = ttgl .float32 , layout = layout )
@@ -386,20 +368,12 @@ def matmul(
386368 triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]),
387369 B ,
388370 )
389- matmul_kernel_batched [grid ](
390- a , b , c , #
391- B , M , N , K , #
392- a .stride (0 ), a .stride (a_major ), a .stride (a_minor ), #
393- b .stride (0 ), b .stride (b_minor ), b .stride (b_major ), #
394- c .stride (0 ), c .stride (1 ), c .stride (2 ))
371+ matmul_kernel_batched [grid ](a , b , c , B , M , N , K , a .stride (0 ), a .stride (a_major ), a .stride (a_minor ), b .stride (0 ),
372+ b .stride (b_minor ), b .stride (b_major ), c .stride (0 ), c .stride (1 ), c .stride (2 ))
395373 elif len (a .shape ) == 2 and len (b .shape ) == 2 :
396374 grid = lambda META : (triton .cdiv (M , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]), )
397- matmul_kernel [grid ](
398- a , b , c , #
399- M , N , K , #
400- a .stride (a_major ), a .stride (a_minor ), #
401- b .stride (b_minor ), b .stride (b_major ), #
402- c .stride (0 ), c .stride (1 ))
375+ matmul_kernel [grid ](a , b , c , M , N , K , a .stride (a_major ), a .stride (a_minor ), b .stride (b_minor ),
376+ b .stride (b_major ), c .stride (0 ), c .stride (1 ))
403377 else :
404378 assert False , 'Input matrixs dimensions mismatch'
405379 return c
@@ -459,7 +433,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
459433 [4 , 32768 , 4096 , 128 ],
460434 [32 , 4096 , 128 , 4096 ],
461435 [4096 , 8 , 128 , 16384 ],
462- # [4096, 8, 16384, 128], # TODO: mismatches for gluon
436+ [4096 , 8 , 16384 , 128 ],
463437]
464438
# Cached XPU device name; presumably used below to gate device-specific
# shapes/providers — confirm against the rest of the file.
DEVICE_NAME = torch.xpu.get_device_name()
@@ -498,13 +472,13 @@ def get_benchmark(
498472 supported_providers = {
499473 'gluon' : 'Gluon' ,
500474 'triton' : 'Triton' ,
501- 'onednn' : 'OneDNN' ,
475+ # 'onednn': 'OneDNN',
502476 }
503477 # use_cutlass
504- if not (transpose_a or transpose_b ):
505- if torch .xpu .get_device_name () != 'Intel(R) Arc(TM) Graphics' :
506- # FIXME: enable cutlass on LNL
507- supported_providers ['cutlass' ] = 'CUTLASS'
478+ # if not (transpose_a or transpose_b):
479+ # if torch.xpu.get_device_name() != 'Intel(R) Arc(TM) Graphics':
480+ # # FIXME: enable cutlass on LNL
481+ # supported_providers['cutlass'] = 'CUTLASS'
508482 providers = benchmark_suite .filter_providers (supported_providers , providers_filter )
509483
510484 # Benchmark Performance