@@ -119,9 +119,9 @@ def warps_per_cta(layout):
119119@pytest .mark .parametrize ("dtype_str" , ["float32" , "float16" , "int8" ])
120120@pytest .mark .parametrize ("layout" , layouts )
121121@pytest .mark .parametrize ("load_block_ptr, store_block_ptr" , [(True , True ), (False , False ), (True , False ),
122- (False , True )])
122+ (False , True )])
123+ @pytest .mark .parametrize ("transpose" , [True , False ])
123123@pytest .mark .skipif (not is_xpu (), reason = "Block store tests are specific to the XPU backend" )
124- def test_block_io (M , N , dtype_str , layout , load_block_ptr , store_block_ptr , device , tmp_path : pathlib .Path ):
124+ def test_block_io (M , N , dtype_str , layout , load_block_ptr , store_block_ptr , transpose , device , tmp_path : pathlib .Path ):
125125
126126 warps = warps_per_cta (layout )
127127 num_warps = int (np .prod (warps ))
@@ -132,16 +132,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
132132
133133 support_block_io = torch .xpu .get_device_capability ()['has_subgroup_2d_block_io' ]
134134
135+ block_io = "\" column_major\" " if transpose else "\" row_major\" "
136+
135137 if load_block_ptr :
136138 load_ops = f"""
137- %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [% N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
138- %store_val = tt.load %src_ptr {{ttig.block_io = "row_major" , boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
139+ %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], { "[%c1_i64, %M_i64]" if transpose else "[% N_i64, %c1_i64]" } , [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
140+ %store_val = tt.load %src_ptr {{ttig.block_io = { block_io } , boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
139141 """
140142 else :
141143 load_ops = f"""
142144 %src_base = tt.splat %src : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
143- %src_ptr = tt.addptr %src_base, % row_major_off : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
144- %store_val = tt.load %src_ptr {{ttig.block_io = "row_major" }} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
145+ %src_ptr = tt.addptr %src_base, { "%col_major_off" if transpose else "% row_major_off" } : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
146+ %store_val = tt.load %src_ptr {{ttig.block_io = { block_io } }} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
145147 """
146148 if store_block_ptr :
147149 store_ops = f"""
@@ -164,7 +166,6 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
164166 %N_i64 = arith.constant { N } : i64
165167 %c1_i64 = arith.constant 1 : i64
166168 %c0_i32 = arith.constant 0 : i32
167-
168169 %stride_N = arith.constant dense<{ N } > : tensor<{ M } x1xi32, #layout>
169170 %1 = tt.make_range {{end = { M } : i32, start = 0 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
170171 %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{ M } x1xi32, #layout>
@@ -175,6 +176,14 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
175176 %7 = tt.broadcast %5 : tensor<1x{ N } xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
176177 %row_major_off = arith.addi %6, %7 : tensor<{ M } x{ N } xi32, #layout>
177178
179+
180+ %stride_M = arith.constant dense<{ M } > : tensor<1x{ N } xi32, #layout>
181+ %col_stride = arith.muli %5, %stride_M : tensor<1x{ N } xi32, #layout>
182+ %8 = tt.broadcast %2 : tensor<{ M } x1xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
183+ %9 = tt.broadcast %col_stride : tensor<1x{ N } xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
184+ %col_major_off = arith.addi %8, %9 : tensor<{ M } x{ N } xi32, #layout>
185+
178187 { load_ops }
179188 { store_ops }
180189
@@ -195,10 +204,14 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
195204 temp_file .write_text (ir )
196205 kernel = triton .compile (str (temp_file ))
197206
207+ a = a .permute (1 , 0 ).contiguous ().permute (1 , 0 ) if transpose else a
208+
198209 kernel [(1 , 1 , 1 )](a , x )
199210 assert torch .equal (a , x )
200211
201212 if support_block_io :
202213 if not load_block_ptr :
203214 assert 'spirv_Subgroup2DBlockLoad' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockRead' in kernel .asm ['llir' ]
204215 assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockWrite' in kernel .asm ['llir' ]
216+ if not load_block_ptr :
217+ assert 'spirv_Subgroup2DBlockLoad' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockRead' in kernel .asm ['llir' ]
0 commit comments