@@ -5433,6 +5433,97 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
54335433 assert torch .equal (z , x )
54345434
54355435
# Distributed (register) layouts used to exercise 3-D shared-memory round
# trips in test_local_load_store: two blocked layouts with different
# per-thread/warp splits across the three dims, plus a DotOperand view of a
# 3-D MMA layout. THREADS_PER_WARP is the module-level warp size constant.
layouts_3d = [
    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    # NOTE(review): op_idx=0 / k_width=1 operand of a 3-D (batched) MMA parent.
    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0,
                     k_width=1),
]
5442+
# Shared-memory layouts paired with layouts_3d: identity (1,1,1) plus three
# swizzled variants (vec, perPhase, maxPhase) over different dim orders, to
# cover distinct bank-conflict-avoidance patterns on the 3-D memdesc.
shared_layout_3d = [
    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
]
5449+
5450+
@pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
@pytest.mark.parametrize("shared_layout", shared_layout_3d)
@pytest.mark.parametrize("dist_layout", layouts_3d)
def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
    """Round-trip an MxNxK i32 tensor through shared memory.

    Hand-written TTGIR loads the input with the distributed layout #dist,
    copies it into a #shared memdesc via ttg.local_alloc, reads it back with
    ttg.local_load, and stores it out; the output must equal the input for
    every (dist_layout, shared_layout) pair.

    Fix: the pasted version of this hunk had scrape-inserted spaces inside the
    f-string literals (e.g. ``tensor<1x{ N } x1xi32``, ``dense<{ K } >``),
    which would render malformed MLIR type text such as ``1x16 x1xi32``; the
    template below restores valid IR.
    """
    layouts = f"""
    #dist = {dist_layout}
    #shared = {shared_layout}
    """
    # The load chain (%0-%18) and the store chain (%22-%40) build identical
    # row-major M*N*K offsets (K stride on dim 1, K*N stride on dim 0); only
    # the shared-memory round trip %19 -> %20 -> %21 sits between them.
    ir = layouts + f"""
    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
    tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
    %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
    %cst_0 = arith.constant dense<{K * N}> : tensor<{M}x1x1xi32, #dist>
    %cst_1 = arith.constant dense<{K * N}> : tensor<{M}x1x1xi32, #dist>
    %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
    %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
    %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
    %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
    %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
    %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
    %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
    %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist>
    %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
    %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
    %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
    %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
    %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist>
    %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
    %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
    %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
    %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
    %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
    %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
    %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
    %25 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
    %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
    %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
    %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist>
    %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
    %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
    %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
    %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
    %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
    %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
    %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist>
    %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
    %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
    tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
    tt.return
    }}
    }}
    """

    # Known backend gap, not a regression: skip-as-expected-failure on XPU.
    if is_xpu() and isinstance(dist_layout, DotOperandLayout) and isinstance(dist_layout.parent, MmaLayout):
        pytest.xfail("DotOperandLayout with MmaLayout is not supported in XPU")

    # 0..M*N*K-1 makes every element unique, so any layout/swizzle mix-up in
    # the shared-memory round trip shows up in the equality check below.
    x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
    z = torch.empty_like(x, device=device)

    # Compile the raw TTGIR from a temp file (no Python frontend involved).
    temp_file = tmp_path / "test_local_load_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    kernel[(1, 1, 1)](x, z)
    assert torch.equal(z, x)
5525+
5526+
54365527mma_pairs = [
54375528 [
54385529 MmaLayout ((2 , 0 ), [1 , 4 ], [1 , 1 ], [1 , 1 ], [0 , 1 ], [16 , 8 ]),
0 commit comments