Commit 689448e

[LoadStoreOpToLLVM] Refactor block load lowering of tt.load with tensor pointer.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 8609010
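
The commit refactors how tt.load is lowered to 2D block IO when the source is a tensor pointer. The two load forms exercised by the updated test below are distinguished by the operand type that reaches tt.load: a block (tensor) pointer is a single scalar pointer to a whole tensor, built by tt.make_tensor_ptr, while the legacy form is a tensor of per-element pointers, built by tt.splat + tt.addptr. A minimal sketch of the contrast, in the test's own f-string style (the concrete M, N, ty values here are illustrative stand-ins, not from this commit):

# Sketch: the operand types that select each load-lowering path.
# M, N, ty stand in for the test's parameters; values are illustrative.
M, N, ty = 64, 32, "f16"

# Block (tensor) pointer: one scalar pointer describing an MxN view,
# loaded with boundaryCheck/padding semantics.
block_pointer_type = f"!tt.ptr<tensor<{M}x{N}x{ty}, #layout>>"

# Tensor of pointers: MxN element pointers, loaded with the
# ttig.block_io = "row_major" hint.
tensor_of_pointers_type = f"tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>"

print(block_pointer_type)
print(tensor_of_pointers_type)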

3 files changed (+619, -625 lines)

python/test/unit/intel/test_block_store.py

Lines changed: 21 additions & 13 deletions
@@ -133,27 +133,35 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
     if block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
-            %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-            tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
 
     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
             %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
@@ -163,9 +171,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
             %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            {load_ops}
 
             {store_ops}
 
@@ -191,3 +197,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     if support_block_io:
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
+        if not block_ptr:
+            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
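
For reference, the tail of the test (outside this hunk) compiles the assembled ir and greps the generated LLVM IR for the block-IO intrinsics. A rough sketch of that flow, assuming the suite's usual pattern of writing the textual TTGIR to the tmp_path fixture and compiling it (the file name and the exact compile call are assumptions, not shown in this diff):

# Hedged sketch of the compile-and-check flow behind the asserts above.
# `ir`, `tmp_path`, `support_block_io` and `block_ptr` are the test's locals.
import triton

temp_file = tmp_path / "test_block_store.ttgir"  # file name is illustrative
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))

if support_block_io:
    # The store must lower to a 2D block write on supporting hardware...
    assert ('spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir']
            or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir'])
    if not block_ptr:
        # ...and with this change, the tensor-of-pointers load must lower
        # to a 2D block read as well.
        assert ('spirv_Subgroup2DBlockLoad' in kernel.asm['llir']
                or 'GenISA.LSC2DBlockRead' in kernel.asm['llir'])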
