@@ -221,6 +221,8 @@ def is_layout_applicable(layout) -> bool:
221221 common_layouts = [BlockedLayout , SharedLayout ]
222222 if layout in common_layouts :
223223 return True
224+ elif isinstance (layout , BlockedLayout ) or isinstance (layout , SharedLayout ):
225+ return True
224226 elif is_cuda ():
225227 mma_layout = layout .parent if isinstance (layout , DotOperandLayout ) else layout
226228 if not isinstance (mma_layout , MmaLayout ):
@@ -238,6 +240,9 @@ def is_layout_applicable(layout) -> bool:
238240 return isinstance (layout , MfmaLayout )
239241 else :
240242 return False
243+ elif is_xpu ():
244+ mma_layout = layout .parent if isinstance (layout , DotOperandLayout ) else layout
245+ return isinstance (mma_layout , DpasLayout )
241246 else :
242247 return True
243248
@@ -2692,6 +2697,23 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa
26922697 BlockedLayout ([1 , 4 ], [8 , THREADS_PER_WARP // 8 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
26932698 BlockedLayout ([4 , 4 ], [THREADS_PER_WARP // 16 , 16 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
26942699 BlockedLayout ([1 , 2 ], [4 , THREADS_PER_WARP // 4 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
2700+ MmaLayout (version = (2 , 0 ), warps_per_cta = [4 , 1 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [0 , 1 ],
2701+ instr_shape = [16 , 8 ]),
2702+ MmaLayout (version = (2 , 0 ), warps_per_cta = [2 , 2 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [0 , 1 ],
2703+ instr_shape = [16 , 8 ]),
2704+ MmaLayout (version = (3 , 0 ), warps_per_cta = [4 , 1 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [1 , 0 ],
2705+ instr_shape = [16 , 16 , 16 ]),
2706+ MmaLayout (version = (3 , 0 ), warps_per_cta = [4 , 2 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [1 , 0 ],
2707+ instr_shape = [16 , 32 , 16 ]),
2708+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [2 , 2 ], instr_shape = [32 , 32 ], is_transposed = False ),
2709+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [4 , 1 ], instr_shape = [32 , 32 ], is_transposed = False ),
2710+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [1 , 4 ], instr_shape = [32 , 32 ], is_transposed = False ),
2711+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [2 , 2 ], instr_shape = [32 , 32 ], is_transposed = True ),
2712+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [4 , 1 ], instr_shape = [32 , 32 ], is_transposed = True ),
2713+ MfmaLayout (version = (2 , 0 ), warps_per_cta = [1 , 4 ], instr_shape = [32 , 32 ], is_transposed = True ),
2714+ WmmaLayout (version = 1 , warps_per_cta = [2 , 2 ]),
2715+ WmmaLayout (version = 1 , warps_per_cta = [4 , 1 ]),
2716+ WmmaLayout (version = 1 , warps_per_cta = [1 , 4 ]),
26952717 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
26962718 warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ]),
26972719 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 16 , ops_per_chan = 2 , threads_per_warp = 32 ,
@@ -2887,6 +2909,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
28872909 # TODO (lixun): Add MfmaLayout
28882910 BlockedLayout ([1 , 4 ], [1 , THREADS_PER_WARP ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
28892911 BlockedLayout ([1 , 4 ], [1 , THREADS_PER_WARP ], [2 , 2 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
2912+ MmaLayout (version = (2 , 0 ), warps_per_cta = [4 , 1 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [0 , 1 ],
2913+ instr_shape = [16 , 8 ]),
28902914 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
28912915 warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ])
28922916]
@@ -5309,6 +5333,9 @@ def kernel(Out):
53095333# TODO: backend should be tested separately
53105334
53115335layouts = [
5336+ MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]),
5337+ DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]), op_idx = 0 , k_width = 2 ),
5338+ DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]), op_idx = 0 , k_width = 1 ),
53125339 BlockedLayout ([1 , 16 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
53135340 BlockedLayout ([1 , 8 ], [2 , THREADS_PER_WARP // 2 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
53145341 BlockedLayout ([1 , 4 ], [4 , THREADS_PER_WARP // 4 ], [2 , 2 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
@@ -5317,6 +5344,15 @@ def kernel(Out):
53175344 BlockedLayout ([4 , 1 ], [8 , THREADS_PER_WARP // 8 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
53185345 BlockedLayout ([1 , 1 ], [THREADS_PER_WARP , 1 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
53195346 BlockedLayout ([4 , 4 ], [1 , THREADS_PER_WARP ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5347+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 2 ),
5348+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 2 ),
5349+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 2 ),
5350+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 2 ),
5351+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 8 ),
5352+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 8 ),
5353+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 8 ),
5354+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 8 ),
5355+ MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]),
53205356 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
53215357 warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ])
53225358]
0 commit comments