@@ -162,6 +162,16 @@ def __str__(self):
162162 return f"#{ GPU_DIALECT } .dot_op<{{parent={ self .parent } , opIdx={ self .op_idx } , kWidth={ self .k_width } }}>"
163163
164164
class SliceLayout:
    """Test-side descriptor that renders a ``slice`` layout attribute string.

    The object only stores its two constructor arguments; all formatting
    happens in ``__str__``.

    Attributes are documented here rather than inline:
        dim: dimension index written into the ``dim = ...`` field.
        parent: layout object being wrapped; it is formatted with ``str()``
            when this layout is rendered, so it must render itself as
            attribute text.
    """

    def __init__(self, dim: int, parent):
        self.dim = dim
        self.parent = parent

    def __str__(self) -> str:
        # The dialect prefix and the attribute name must be contiguous
        # ("#<dialect>.slice"); a space between them, or before the commas
        # and closing braces, would change the emitted attribute text.
        # Doubled braces are literal "{" / "}" in an f-string.
        return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>"
173+
174+
165175class BlockedLayout :
166176
167177 def __init__ (self , size_per_thread , threads_per_warp , warps_per_cta , order , ctas_per_cga , cta_split_num , cta_order ):
@@ -199,6 +209,8 @@ def is_layout_applicable(layout) -> bool:
199209 common_layouts = [BlockedLayout , SharedLayout ]
200210 if layout in common_layouts :
201211 return True
212+ elif isinstance (layout , SliceLayout ):
213+ return is_layout_applicable (layout .parent )
202214 elif is_cuda ():
203215 mma_layout = layout .parent if isinstance (layout , DotOperandLayout ) else layout
204216 if not isinstance (mma_layout , MmaLayout ):
@@ -2850,8 +2862,11 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
28502862 # TODO (lixun): Add MfmaLayout
28512863 BlockedLayout ([1 , 4 ], [1 , THREADS_PER_WARP ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
28522864 BlockedLayout ([1 , 4 ], [1 , THREADS_PER_WARP ], [2 , 2 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
2865+ MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]),
28532866 MmaLayout (version = (2 , 0 ), warps_per_cta = [4 , 1 ], ctas_per_cga = [1 , 1 ], cta_split_num = [1 , 1 ], cta_order = [0 , 1 ],
2854- instr_shape = [16 , 8 ])
2867+ instr_shape = [16 , 8 ]),
2868+ DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]), op_idx = 0 , k_width = 2 ),
2869+ DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 2 ),
28552870]
28562871
28572872
@@ -5281,17 +5296,12 @@ def kernel(Out):
52815296# TODO: backend should be tested separately
52825297
52835298layouts = [
5299+ BlockedLayout ([1 , 1 ], [THREADS_PER_WARP , 1 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5300+ BlockedLayout ([1 , 16 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
52845301 MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]),
52855302 DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]), op_idx = 0 , k_width = 2 ),
52865303 DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 32 , 16 ]), op_idx = 0 , k_width = 1 ),
5287- BlockedLayout ([1 , 16 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5288- BlockedLayout ([1 , 8 ], [2 , THREADS_PER_WARP // 2 ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5289- BlockedLayout ([1 , 4 ], [4 , THREADS_PER_WARP // 4 ], [2 , 2 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5290- BlockedLayout ([1 , 1 ], [1 , THREADS_PER_WARP ], [2 , 2 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5291- BlockedLayout ([8 , 1 ], [16 , THREADS_PER_WARP // 16 ], [1 , 4 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5292- BlockedLayout ([4 , 1 ], [8 , THREADS_PER_WARP // 8 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5293- BlockedLayout ([1 , 1 ], [THREADS_PER_WARP , 1 ], [2 , 2 ], [0 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5294- BlockedLayout ([4 , 4 ], [1 , THREADS_PER_WARP ], [4 , 1 ], [1 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 1 ]),
5304+ MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]),
52955305 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 2 ),
52965306 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 2 ),
52975307 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 2 ),
@@ -5300,7 +5310,13 @@ def kernel(Out):
53005310 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 8 ),
53015311 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 0 , k_width = 8 ),
53025312 DotOperandLayout (parent = MmaLayout ([2 , 0 ], [2 , 2 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]), op_idx = 1 , k_width = 8 ),
5303- MmaLayout ([2 , 0 ], [4 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 0 ], [16 , 8 ]),
5313+ SliceLayout (
5314+ dim = 1 ,
5315+ parent = DotOperandLayout (parent = MmaLayout ([3 , 0 ], [4 , 1 , 1 ], [1 , 1 , 1 ], [1 , 1 , 1 ], [2 , 1 , 0 ], [16 , 32 , 16 ]),
5316+ op_idx = 0 , k_width = 2 )),
5317+ SliceLayout (
5318+ dim = 1 , parent = DotOperandLayout (parent = MmaLayout ([2 , 0 ], [4 , 1 , 1 ], [1 , 1 , 1 ], [1 , 1 , 1 ], [2 , 1 , 0 ], [1 , 16 , 8 ]),
5319+ op_idx = 1 , k_width = 2 )),
53045320]
53055321
53065322intermediate_layouts = [
0 commit comments