@@ -5383,6 +5383,94 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
53835383 assert torch .equal (z , x )
53845384
53855385
# Distributed (register) layouts exercised by the 3D local load/store test:
# two blocked layouts with different per-thread/warp tilings plus one MMA
# dot-operand layout. THREADS_PER_WARP scales the thread tiling per target.
layouts_3d = [
    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0,
                     k_width=1),
]
5392+
# Shared-memory layouts for the 3D local load/store test, covering the
# identity layout plus several swizzled variants (vec, perPhase, maxPhase)
# with different dimension orders.
shared_layout_3d = [
    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
]
5399+
5400+
@pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
@pytest.mark.parametrize("shared_layout", shared_layout_3d)
@pytest.mark.parametrize("dist_layout", layouts_3d)
def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
    # Round-trips a 3D MxNxK i32 tensor through shared memory: the kernel
    # loads the input in the distributed layout #dist, copies it into a
    # shared-memory buffer with layout #shared (local_alloc), reads it back
    # out (local_load), and stores it to the output pointer. The test passes
    # iff the output equals the input for every layout pair.
    layouts = f"""
    #dist = {dist_layout}
    #shared = {shared_layout}
    """
    # TTGIR below builds the MxNxK pointer grid twice (once for %arg0 input,
    # once for %arg1 output) via make_range / expand_dims / broadcast, with
    # strides K*N for dim 0 and K for dim 1 (row-major contiguous input).
    ir = layouts + f"""
    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
    tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
    %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
    %cst_0 = arith.constant dense<{K * N}> : tensor<{M}x1x1xi32, #dist>
    %cst_1 = arith.constant dense<{K * N}> : tensor<{M}x1x1xi32, #dist>
    %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
    %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
    %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
    %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
    %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
    %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
    %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
    %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist>
    %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
    %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
    %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
    %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
    %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist>
    %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
    %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
    %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
    %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
    %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
    %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
    %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
    %25 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
    %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
    %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
    %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist>
    %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
    %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
    %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
    %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
    %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist>
    %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
    %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
    tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    tt.return
    }}
    }}
    """

    # Sequential input values 0..M*N*K-1 make any element-level shuffle by a
    # wrong layout conversion immediately visible in the equality check.
    x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
    z = torch.empty_like(x, device=device)

    # Compile the hand-written TTGIR directly from a file (no @triton.jit
    # frontend involved), then launch a single program instance.
    temp_file = tmp_path / "test_local_load_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    kernel[(1, 1, 1)](x, z)
    assert torch.equal(z, x)
5472+
5473+
53865474mma_pairs = [
53875475 [
53885476 MmaLayout ((2 , 0 ), [1 , 4 ], [1 , 1 ], [1 , 1 ], [0 , 1 ], [16 , 8 ]),
0 commit comments