@@ -179,17 +179,20 @@ def __str__(self):
 
 class SharedLayout:
 
-    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
+    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order,
+                 has_leading_offset=False):
         self.vec = vec
         self.per_phase = per_phase
         self.max_phase = max_phase
         self.order = order
         self.ctas_per_cga = ctas_per_cga
         self.cta_split_num = cta_split_num
         self.cta_order = cta_order
+        self.has_leading_offset = has_leading_offset
 
     def __str__(self):
-        return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        has_leading_offset_str = "true" if self.has_leading_offset else "false"
+        return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, hasLeadingOffset={has_leading_offset_str}}}>"
 
 
 def is_layout_applicable(layout) -> bool:
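For reference, a minimal standalone sketch (not part of the diff) of the attribute string the extended `__str__` now emits; it assumes `GPU_DIALECT` is `"ttg"`, as on the NVIDIA path of this test file. This string is what gets spliced into the TTGIR in the tests below.

```python
# Standalone sketch of the new shared-layout string, assuming GPU_DIALECT == "ttg".
vec, per_phase, max_phase = 8, 2, 4
order, ctas_per_cga, cta_split_num, cta_order = [1, 0], [1, 1], [1, 1], [0, 1]
has_leading_offset = True

# Mirrors SharedLayout.__str__ after this change.
has_leading_offset_str = "true" if has_leading_offset else "false"
attr = (f"#ttg.shared<{{vec={vec}, perPhase={per_phase}, maxPhase={max_phase}, "
        f"order={order}, CTAsPerCGA={ctas_per_cga}, CTASplitNum={cta_split_num}, "
        f"CTAOrder={cta_order}, hasLeadingOffset={has_leading_offset_str}}}>")
print(attr)
# #ttg.shared<{vec=8, perPhase=2, maxPhase=4, order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], hasLeadingOffset=true}>
```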
@@ -5418,7 +5421,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
                      k_width=1),
 ]
 
-shared_layout_3d = [
+shared_layouts_3d = [
     SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
     SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
     SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
@@ -5427,8 +5430,8 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
 
 
 @pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
-@pytest.mark.parametrize("shared_layout", shared_layout_3d)
-@pytest.mark.parametrize("dist_layout", layouts_3d)
+@pytest.mark.parametrize("shared_layout", shared_layouts_3d)
+@pytest.mark.parametrize("dist_layout", filter_layouts(layouts_3d))
 def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
     layouts = f"""
     #dist = {dist_layout}
@@ -5500,6 +5503,72 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
     assert torch.equal(z, x)
 
 
+mma_layouts = [
+    MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
+    MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # simple 4 warps case
+    MmaLayout((3, 0), [8, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # simple 8 warps case
+    MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # multiple warps on the row
+    MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]),  # small instrN
+    MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]),  # large number of warps
+]
+
+shared_layouts = [
+    SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SharedLayout(8, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True),  # small contiguous bytes
+    SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True),  # maximum contiguous bytes
+]
+
+
+@pytest.mark.parametrize("M, N", [[128, 128]])
+@pytest.mark.parametrize("mma_layout", filter_layouts(mma_layouts))
+@pytest.mark.parametrize("shared_layout", shared_layouts)
+def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: pathlib.Path):
+    num_warps = np.prod(mma_layout.warps_per_cta)
+
+    layouts = f"""
+    #dist = {mma_layout}
+    #shared = {shared_layout}
+    #smem = #ttg.shared_memory
+    """
+    ir = layouts + f"""
+    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
+    tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
+    %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist>
+    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+    %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>>
+    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+    %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist>
+    %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist>
+    %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist>
+    %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
+    %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
+    %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist>
+    %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>, tensor<{M}x{N}xi32, #dist>
+    %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+    %12 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #dist>) -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem>
+    %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem> -> tensor<{M}x{N}xf16, #dist>
+    %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>, tensor<{M}x{N}xi32, #dist>
+    tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+    tt.return
+    }}
+    }}
+    """
+
+    x = torch.arange(0, M * N, device=device, dtype=torch.float16).reshape(M, N)
+    z = torch.empty_like(x, device=device)
+
+    temp_file = tmp_path / "test_local_load_store_mma.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
+    kernel[(1, 1, 1)](x, z)
+    assert torch.equal(z, x)
+    if shared_layout.has_leading_offset and mma_layout.version[0] >= 3:
+        assert "stmatrix" in kernel.asm["ptx"]
+
+
 mma_pairs = [
     [
         MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
@@ -5546,18 +5615,10 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 
 @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("mma_pair", mma_pairs)
-def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
-    if is_hip():
-        pytest.skip("test_mma2mma is not supported in HIP")
-
+@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs))
+def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
     src_layout, _ = mma_pair
-    if is_cuda():
-        cc = torch.cuda.get_device_capability()
-        if cc[0] < 9 and src_layout.version[0] >= 3:
-            pytest.skip("Skip testing MMAv3 on devices with CC < 9")
-
-    num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
+    num_warps = np.prod(src_layout.warps_per_cta)
 
     def do_test(src_layout, dst_layout):
         layouts = f"""
@@ -5593,7 +5654,7 @@ def do_test(src_layout, dst_layout):
         x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
         z = torch.empty_like(x)
 
-        temp_file = tmp_path / "test_convertmma2mma.ttgir"
+        temp_file = tmp_path / "test_convert_mma2mma.ttgir"
         temp_file.write_text(ir)
         kernel = triton.compile(str(temp_file))
 