
Commit 84c8ab4

[intel] Implement filter_layouts (#3045)
By implementing `filter_layouts`, we can add back layouts from other backends to reduce differences from upstream.

Signed-off-by: Whitney Tsang <[email protected]>
Parent: 7eb41bf

1 file changed: +36 -0 lines

python/test/unit/language/test_core.py

Lines changed: 36 additions & 0 deletions
@@ -221,6 +221,8 @@ def is_layout_applicable(layout) -> bool:
     common_layouts = [BlockedLayout, SharedLayout]
     if layout in common_layouts:
         return True
+    elif isinstance(layout, BlockedLayout) or isinstance(layout, SharedLayout):
+        return True
     elif is_cuda():
         mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
         if not isinstance(mma_layout, MmaLayout):
@@ -238,6 +240,9 @@ def is_layout_applicable(layout) -> bool:
             return isinstance(layout, MfmaLayout)
         else:
             return False
+    elif is_xpu():
+        mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
+        return isinstance(mma_layout, DpasLayout)
     else:
         return True
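The branches above are what `filter_layouts` consults. The helper itself is outside this diff; assuming it matches the upstream one-liner in `test_core.py`, it is essentially:

def filter_layouts(layouts):
    # Keep only the layouts applicable on the active backend; the rest
    # are dropped before pytest parametrization.
    return [l for l in layouts if is_layout_applicable(l)]

With the new `is_xpu()` branch, NVIDIA `MmaLayout` and AMD `MfmaLayout`/`WmmaLayout` entries can stay in the shared lists and simply filter out on Intel hardware, which is what lets the hunks below add them back.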

@@ -2692,6 +2697,23 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
     BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1],
+              instr_shape=[16, 8]),
+    MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1],
+              instr_shape=[16, 8]),
+    MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0],
+              instr_shape=[16, 16, 16]),
+    MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0],
+              instr_shape=[16, 32, 16]),
+    MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False),
+    MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False),
+    MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False),
+    MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True),
+    MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True),
+    MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True),
+    WmmaLayout(version=1, warps_per_cta=[2, 2]),
+    WmmaLayout(version=1, warps_per_cta=[4, 1]),
+    WmmaLayout(version=1, warps_per_cta=[1, 4]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
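These lists feed pytest parametrization; the decorator is outside the diff context, but assuming the tests follow the upstream pattern of wrapping the list in `filter_layouts`, the mechanism works as in the self-contained sketch below (stand-in classes and a hypothetical `test_layouts_sketch`, not the real test):

import pytest

# Stand-ins so the sketch runs on its own; in test_core.py these are the
# real TTGIR layout wrappers plus the shared filter_layouts helper.
class BlockedLayout:
    pass

class DpasLayout:
    pass

def filter_layouts(layouts):
    # Pretend we are on an XPU target: only Blocked and DPAS survive.
    return [l for l in layouts if isinstance(l, (BlockedLayout, DpasLayout))]

layouts = [BlockedLayout(), DpasLayout()]

# Inapplicable layouts are dropped at collection time, so MMA/MFMA/WMMA
# cases never even appear as test IDs on an XPU runner.
@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
def test_layouts_sketch(src_layout):
    assert isinstance(src_layout, (BlockedLayout, DpasLayout))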
@@ -2887,6 +2909,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
     # TODO (lixun): Add MfmaLayout
     BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1],
+              instr_shape=[16, 8]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1])
 ]
@@ -5309,6 +5333,9 @@ def kernel(Out):
 # TODO: backend should be tested separately

 layouts = [
+    MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]),
+    DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2),
+    DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1),
     BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -5317,6 +5344,15 @@ def kernel(Out):
     BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
+    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2),
+    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
+    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8),
+    DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8),
+    MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1])
 ]
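Several of the restored cases are `DotOperandLayout`s whose parent is an NVIDIA `MmaLayout`. The new `is_xpu()` branch judges these by the parent rather than the wrapper, so they filter out on Intel while a DPAS-parented operand would pass. A minimal illustration of that unwrapping, with simplified stand-in classes rather than the real wrappers:

# Simplified stand-ins for the layout wrappers used in test_core.py.
class MmaLayout:
    pass

class DpasLayout:
    pass

class DotOperandLayout:
    def __init__(self, parent):
        self.parent = parent

def xpu_applicable(layout) -> bool:
    # Mirrors the new is_xpu() branch: a DotOperandLayout is judged by its
    # parent layout, any other layout by its own type.
    mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
    return isinstance(mma_layout, DpasLayout)

assert not xpu_applicable(DotOperandLayout(parent=MmaLayout()))  # filtered on XPU
assert xpu_applicable(DotOperandLayout(parent=DpasLayout()))     # kept on XPU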
