@@ -25,80 +25,145 @@ def __str__(self):
2525 return f"#ttig.dpas<{{repeatCount={ self .repeatCount } , systolicDepth={ self .systolic_depth } , executionSize = { self .execution_size } , opsPerChan = { self .ops_per_chan } , threadsPerWarp = { self .threads_per_warp } , warpsPerCTA={ self .warps_per_cta } , repCluster={ self .rep_cluster } }}>"
2626
2727
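+# The helper classes below mirror TritonGPU layout attributes; each __str__ renders the MLIR
+# attribute syntax (#ttg.dot_op / #ttg.slice / #ttg.blocked) that is substituted into the IR template further down.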
+class DotOperandLayout:
+
+    def __init__(self, parent, op_idx, k_width):
+        self.parent = parent
+        self.op_idx = op_idx
+        self.k_width = k_width
+        self.threads_per_warp = parent.threads_per_warp
+
+    def __str__(self):
+        return f"#ttg.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>"
+
+
+class SliceLayout:
+
+    def __init__(self, dim, parent):
+        self.dim = dim
+        self.parent = parent
+        self.threads_per_warp = parent.threads_per_warp
+
+    def __str__(self):
+        return f"#ttg.slice<{{dim = {self.dim}, parent = {self.parent}}}>"
+
+
+class BlockedLayout:
+
+    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
+                 cta_split_num=[1, 1], cta_order=[0, 1]):
+        self.sz_per_thread = size_per_thread
+        self.threads_per_warp = threads_per_warp
+        self.warps_per_cta = warps_per_cta
+        self.order = order
+        self.ctas_per_cga = ctas_per_cga
+        self.cta_split_num = cta_split_num
+        self.cta_order = cta_order
+
+    def __str__(self):
+        return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+
+
def warps_per_cta(layout):
-    return layout.warps_per_cta
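+    # Slice and dot-operand layouts are wrappers: take the warp count from their parent layout.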
+    if isinstance(layout, (SliceLayout, DotOperandLayout)):
+        return warps_per_cta(layout.parent)
+    else:
+        return layout.warps_per_cta


layouts = [
-    # Layout for Xe
+    BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    # DPAS layout
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
               warps_per_cta=[1, 4], rep_cluster=[1, 2]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[8, 4], rep_cluster=[4, 2]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
               warps_per_cta=[8, 4], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
+               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
+               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=4, threads_per_warp=32,
+               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
+    # DotOp A
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=0, k_width=1),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=0, k_width=1),
+    # DotOp B
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=2),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=4),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
+    # Slice layout
+    SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0], [1, 1, 1], [1, 1, 1],
+                                            [0, 1, 2])),
]


-@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
+@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128], [64, 128])])
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
@pytest.mark.parametrize("layout", layouts)
+@pytest.mark.parametrize("block_ptr", [True, False])
@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
+def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):

    warps = warps_per_cta(layout)
    num_warps = int(np.prod(warps))
    threads_per_warp = layout.threads_per_warp
-    ops_per_chan = layout.ops_per_chan
-    A_width = 1 if ops_per_chan == 1 else ops_per_chan // 2
-    B_width = ops_per_chan
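+    # threads_per_warp may be a per-dimension list (e.g. for BlockedLayout), so reduce it to a scalar count.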
+    threads_per_warp = int(np.prod(threads_per_warp))

    ty = {"float32": "f32", "float16": "f16", "bfloat16": "i16", "int8": "i8"}[dtype_str]

    support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']

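+    # Build the store either through a block pointer (tt.make_tensor_ptr + tt.store with boundary check)
+    # or through a plain tensor of pointers; both mark the store with ttig.block_io = "row_major".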
+    if block_ptr:
+        store_ops = f"""
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
+
+            %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            """
+    else:
+        store_ops = f"""
+            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            """
+
    ir = f"""
-    #mma = {layout}
-    #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
-    #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
+    #layout = {layout}
    module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
-        tt.func public @tensor_pointer_block_store(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) {{
-
-            // A matrix
-            %stride_a = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_a>
-            %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
-            %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
-            %3 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
-            %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
-            %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
-            %6 = tt.broadcast %3 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-            %7 = tt.broadcast %5 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-            %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #dot_a>
-
-            %9 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-            %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-            %12 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-            tt.store %13, %11 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-
-            // B matrix
-            %stride_b = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_b>
-            %21 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
-            %22 = tt.expand_dims %21 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
-            %23 = arith.muli %22, %stride_b : tensor<{M}x1xi32, #dot_b>
-            %24 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
-            %25 = tt.expand_dims %24 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
-            %26 = tt.broadcast %23 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-            %27 = tt.broadcast %25 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-            %28 = arith.addi %26, %27 : tensor<{M}x{N}xi32, #dot_b>
-
-            %29 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-            %30 = tt.addptr %29, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-            %31 = tt.load %30 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-            %32 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-            %33 = tt.addptr %32, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-            tt.store %33, %31 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+        tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
+
+            %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
+            %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
+            %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
+            %3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
+            %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
+            %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
+            %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+            %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+            %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
+            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+
+            {store_ops}

            tt.return
        }}
@@ -112,11 +177,10 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: p
    a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device)

    x = torch.empty_like(a)
-    y = torch.empty_like(a)

-    temp_file = tmp_path / "test_tensor_pointer_block_store.ttgir"
+    temp_file = tmp_path / "test_block_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

-    kernel[(1, 1, 1)](a, x, a, y)
-    assert torch.equal(a, x) and torch.equal(a, y)
+    kernel[(1, 1, 1)](a, x)
+    assert torch.equal(a, x)