
Commit bd11640

[TEST] Enhance test_block_io.py (#4907)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 6cb654f commit bd11640

3 files changed: +32 / -20 lines

python/test/unit/intel/test_block_store.py renamed to python/test/unit/intel/test_block_io.py

Lines changed: 29 additions & 19 deletions
```diff
@@ -118,9 +118,10 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128], [64, 128])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
-@pytest.mark.parametrize("block_ptr", [True, False])
+@pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
+                                                             (False, True)])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
```
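
The four explicit `(load_block_ptr, store_block_ptr)` tuples in the new parametrization are exactly the cross-product of two booleans, so every load style is exercised against every store style. A minimal sketch of an equivalent, more compact form (using the `itertools` import the test file already has; not part of this commit):

```python
import itertools

# Same four combinations as the explicit tuple list in the diff above.
combos = list(itertools.product([True, False], repeat=2))
assert set(combos) == {(True, True), (True, False), (False, True), (False, False)}
```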
```diff
@@ -131,41 +132,50 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
-    if block_ptr:
+    if load_block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
+    else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
+    if store_block_ptr:
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
             %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
             tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
 
     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
-            %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
+
+            %stride_N = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
-            %3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
+            %row_stride = arith.muli %2, %stride_N : tensor<{M}x1xi32, #layout>
             %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
             %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
-            %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+            %6 = tt.broadcast %row_stride : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
-            %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+            {load_ops}
             {store_ops}
 
             tt.return
```
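
For readers who don't write raw TTGIR every day: the two load paths and two store paths above correspond to Triton's two addressing styles, block pointers (`tt.make_tensor_ptr`, exposed in the Python API as `tl.make_block_ptr`) versus a plain tensor of pointers (`tt.splat` + `tt.addptr`). A rough Python-level analogue, not part of this commit, just a sketch of what the generated IR exercises:

```python
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src, dst, M: tl.constexpr, N: tl.constexpr,
                LOAD_BLOCK_PTR: tl.constexpr, STORE_BLOCK_PTR: tl.constexpr):
    # Row-major offsets for the tensor-of-pointers paths (tt.splat + tt.addptr).
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    if LOAD_BLOCK_PTR:
        # Block-pointer load: tt.make_tensor_ptr + tt.load with boundaryCheck.
        src_ptr = tl.make_block_ptr(src, shape=(M, N), strides=(N, 1), offsets=(0, 0),
                                    block_shape=(M, N), order=(1, 0))
        val = tl.load(src_ptr, boundary_check=(0, 1), padding_option="zero")
    else:
        val = tl.load(src + offs)
    if STORE_BLOCK_PTR:
        dst_ptr = tl.make_block_ptr(dst, shape=(M, N), strides=(N, 1), offsets=(0, 0),
                                    block_shape=(M, N), order=(1, 0))
        tl.store(dst_ptr, val, boundary_check=(0, 1))
    else:
        tl.store(dst + offs, val)
```

The test bypasses this frontend entirely and feeds hand-written TTGIR to the compiler, so the `ttig.block_io` hints and the chosen `#layout` are pinned down exactly.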
```diff
@@ -181,7 +191,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     x = torch.empty_like(a)
 
-    temp_file = tmp_path / "test_block_store.ttgir"
+    temp_file = tmp_path / "test_block_io.ttgir"
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
```
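As the context lines show, the test writes the generated TTGIR to a temporary file and compiles it directly. A minimal standalone sketch of that pattern (assuming `ir` holds TTGIR that is valid for the current device):

```python
import pathlib
import tempfile

import triton


def compile_ttgir(ir: str):
    """Compile a raw TTGIR string the same way the test does."""
    with tempfile.TemporaryDirectory() as tmp:
        temp_file = pathlib.Path(tmp) / "kernel.ttgir"
        temp_file.write_text(ir)
        # triton.compile accepts a path to a .ttgir file and compiles
        # the IR directly, bypassing the Python frontend.
        return triton.compile(str(temp_file))
```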
scripts/skiplist/lts/intel.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,2 +1,2 @@
 python/test/unit/intel/test_block_load.py::test_block_load_dpas_layout
-python/test/unit/intel/test_block_store.py::test_block_store
+python/test/unit/intel/test_block_io.py::test_block_io
```

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -2119,6 +2119,8 @@ struct LoadOpToBlockIOConversion
       return failure();
     numOperandsPer2DLoadN =
         std::min(numOperandsPer2DLoadN, MAX_WIDTH / totalBytesPerRowPerDPASOp);
+    // vBlocks has HW limitation of 4.
+    numOperandsPer2DLoadN = std::min(numOperandsPer2DLoadN, 4u);
 
     tileHeight = instHeight * numOperandsPer2DLoadM;
     tileWidth = instWidth;
```
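
The added clamp bounds the number of column operands packed into one 2D block load by the hardware's vBlocks limit of 4, on top of the existing payload-width bound. A minimal Python stand-in for the resulting selection logic (names mirror the C++ above; the hunk is the authoritative version):

```python
def num_operands_per_2d_load_n(requested: int, max_width: int,
                               total_bytes_per_row_per_dpas_op: int) -> int:
    # Bound by how many DPAS operand rows fit in the 2D load payload width...
    n = min(requested, max_width // total_bytes_per_row_per_dpas_op)
    # ...and by the hardware's vBlocks limit of 4 (the clamp added in this hunk).
    return min(n, 4)


# Example: a payload that would allow 8 operands is still capped at 4.
assert num_operands_per_2d_load_n(8, 64, 8) == 4
```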
