@@ -62,33 +62,26 @@ def __str__(self):
6262
class BlockedLayout:
    """Holds the parameters of a blocked layout and renders them as attribute text.

    ``str(layout)`` yields ``#<GPU_DIALECT>.blocked<{...}>``, the same shape as the
    literal ``#ttg.blocked<{...}>`` attribute written out in the IR snippets of
    this file. The values are stored as given; nothing is validated or copied.
    """

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        """Record the four layout parameters.

        Args:
            size_per_thread: value printed as ``sizePerThread``.
            threads_per_warp: value printed as ``threadsPerWarp``.
            warps_per_cta: value printed as ``warpsPerCTA``.
            order: value printed as ``order``.
        """
        # Note: the first argument is stored under the shortened attribute name
        # ``sz_per_thread``; code reading the attribute must use that spelling.
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order

    def __str__(self):
        """Return the layout as ``#<GPU_DIALECT>.blocked<{...}>`` text."""
        # ``{{`` / ``}}`` are literal braces in an f-string; the values are
        # interpolated with their default ``str()`` form (e.g. ``[1, 4]``).
        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
7773
7874
class SwizzledSharedLayout:
    """Holds the parameters of a swizzled shared layout and renders them as text.

    ``str(layout)`` yields ``#<GPU_DIALECT>.swizzled_shared<{...}>``. The values
    are stored as given; nothing is validated or copied.
    """

    def __init__(self, vec, per_phase, max_phase, order):
        """Record the four layout parameters.

        Args:
            vec: value printed as ``vec``.
            per_phase: value printed as ``perPhase``.
            max_phase: value printed as ``maxPhase``.
            order: value printed as ``order``.
        """
        self.vec = vec
        self.per_phase = per_phase
        self.max_phase = max_phase
        self.order = order

    def __str__(self):
        """Return the layout as ``#<GPU_DIALECT>.swizzled_shared<{...}>`` text."""
        # ``{{`` / ``}}`` are literal braces in an f-string; the values are
        # interpolated with their default ``str()`` form.
        return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
9285
9386
9487class PaddedSharedLayout :
@@ -172,17 +165,17 @@ def get_reduce_input(dtype_str, shape):
172165
173166
# Blocked layouts collected under ``scan_layouts``.
# Positional arguments are (size_per_thread, threads_per_warp, warps_per_cta, order).
# The first five entries use order [0, 1]; the remaining six use order [1, 0].
scan_layouts = [
    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0]),
]
187180
188181
@@ -254,8 +247,8 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa
254247
255248
256249layouts = [
257- BlockedLayout ([1 , 4 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ], [ 1 , 1 ], [ 1 , 1 ], [ 0 , 1 ] ),
258- BlockedLayout ([1 , 4 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [0 , 1 ], [ 1 , 1 ], [ 1 , 1 ], [ 0 , 1 ] ),
250+ BlockedLayout ([1 , 4 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ]),
251+ BlockedLayout ([1 , 4 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [0 , 1 ]),
259252 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
260253 warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ]),
261254 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 16 , ops_per_chan = 2 , threads_per_warp = 32 ,
@@ -305,8 +298,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
305298 store_range = "%7" if axis == 0 else "%1"
306299 warps = warps_per_cta (src_layout , [M , N ])
307300 num_warps = int (np .prod (warps ))
308- blocked = BlockedLayout ([1 , 1 ], [32 , THREADS_PER_WARP // 32 ], [4 , num_warps // 4 ], [0 , 1 ], [ 1 , 1 ], [ 1 , 1 ], [ 0 , 1 ] )
309- one_d_layout = BlockedLayout ([1 ], [THREADS_PER_WARP ], [num_warps ], [0 ], [ 1 ], [ 1 ], [ 0 ] )
301+ blocked = BlockedLayout ([1 , 1 ], [32 , THREADS_PER_WARP // 32 ], [4 , num_warps // 4 ], [0 , 1 ])
302+ one_d_layout = BlockedLayout ([1 ], [THREADS_PER_WARP ], [num_warps ], [0 ])
310303
311304 expanded_shape = f"1x{ N } " if axis == 0 else f"{ M } x1"
312305 other_axis = 1 - axis
@@ -397,8 +390,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
397390
398391
# Two blocked layouts plus one DPAS layout.
# BlockedLayout arguments are (size_per_thread, threads_per_warp, warps_per_cta, order).
layouts = [
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
]
@@ -443,8 +436,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
443436
444437
# Two blocked layouts plus one DPAS layout.
# BlockedLayout arguments are (size_per_thread, threads_per_warp, warps_per_cta, order).
layouts = [
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
               warps_per_cta=[4, 1], rep_cluster=[1, 1])
]
@@ -532,10 +525,10 @@ def test_convert1d_bool(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp
532525
533526
# Four blocked layouts.
# Arguments are (size_per_thread, threads_per_warp, warps_per_cta, order).
layouts = [
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
    BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0]),
    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1])
]
540533
541534
@@ -611,8 +604,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
611604# TODO: backend should be tested separately
612605
613606layouts = [
614- BlockedLayout ([1 , 1 ], [THREADS_PER_WARP , 1 ], [2 , 2 ], [0 , 1 ], [ 1 , 1 ], [ 1 , 1 ], [ 0 , 1 ] ),
615- BlockedLayout ([1 , 16 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ], [ 1 , 1 ], [ 1 , 1 ], [ 0 , 1 ] ),
607+ BlockedLayout ([1 , 1 ], [THREADS_PER_WARP , 1 ], [2 , 2 ], [0 , 1 ]),
608+ BlockedLayout ([1 , 16 ], [8 , THREADS_PER_WARP // 8 ], [4 , 1 ], [1 , 0 ]),
616609 DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
617610 warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ]),
618611 DpasLayout (repeatCount = 2 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
@@ -621,10 +614,10 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
621614
# SwizzledSharedLayout arguments are (vec, per_phase, max_phase, order).
# NOTE(review): the leading None presumably stands for "no intermediate layout";
# confirm against the test that consumes this list.
intermediate_layouts = [
    None,
    SwizzledSharedLayout(1, 1, 1, [0, 1]),
    SwizzledSharedLayout(1, 1, 1, [1, 0]),
    SwizzledSharedLayout(4, 2, 4, [1, 0]),
    SwizzledSharedLayout(2, 2, 4, [1, 0]),
]
629622
630623
@@ -736,15 +729,15 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
736729
737730
# Blocked layouts whose parameter lists all have three entries.
# Arguments are (size_per_thread, threads_per_warp, warps_per_cta, order).
layouts_3d = [
    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0]),
    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0]),
]
742735
# Swizzled shared layouts with three-entry orders.
# Arguments are (vec, per_phase, max_phase, order).
shared_layouts_3d = [
    SwizzledSharedLayout(1, 1, 1, [2, 1, 0]),
    SwizzledSharedLayout(4, 2, 4, [1, 2, 0]),
    SwizzledSharedLayout(8, 2, 4, [0, 2, 1]),
    SwizzledSharedLayout(4, 2, 1, [2, 0, 1]),
]
749742
750743
@@ -841,9 +834,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
841834]
842835
# Swizzled shared layouts; arguments are (vec, per_phase, max_phase, order).
shared_layouts = [
    SwizzledSharedLayout(4, 2, 4, [0, 1]),
    SwizzledSharedLayout(8, 1, 8, [1, 0]),
    SwizzledSharedLayout(16, 1, 16, [1, 0]),
]
848841
849842
@@ -855,7 +848,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib
855848 num_repeats_N = triton .cdiv (N , N_tile_size )
856849
857850 ir = f"""
858- #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{ num_rows_per_warp } , 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1] }}>
851+ #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{ num_rows_per_warp } , 4], warpsPerCTA=[4, 1], order=[1, 0]}}>
859852 #shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}>
860853 #smem = #ttg.shared_memory
861854
@@ -989,7 +982,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
989982]
990983
# Single swizzled shared layout; arguments are (vec, per_phase, max_phase, order).
shared_layouts = [
    SwizzledSharedLayout(8, 1, 1, [1, 0]),
]
994987
995988
0 commit comments