
Commit 6b65fa3

[LoadStoreOpToLLVM] Enable the block store for tensor pointer (#4666)
Enable the block store for tensor pointer.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 9122b6f commit 6b65fa3
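
As context (not part of the commit), a minimal Python sketch of the tensor-pointer store this lowering targets. The kernel name, shapes, and launch parameters are illustrative assumptions; on Intel GPUs that advertise ttig.support_sg_2d_block, a tl.store through a block pointer like this is expected to lower to a 2D block write such as spirv_Subgroup2DBlockStoreINTEL or GenISA.LSC2DBlockWrite.

import triton
import triton.language as tl

@triton.jit
def store_tile_block_ptr(out_ptr, M, N, stride_m,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Tensor (block) pointer describing a BLOCK_M x BLOCK_N tile of a row-major matrix.
    blk = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, 1),
                            offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    vals = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)
    # Storing the whole tile through the tensor pointer is the path that can be
    # lowered to a 2D block store by LoadStoreOpToLLVM.
    tl.store(blk, vals, boundary_check=(0, 1))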

File tree

5 files changed: +290 -54 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
+    "TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",

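The new variable joins CACHE_INVALIDATING_ENV_VARS, which (as the set's name suggests) means its value is expected to take part in Triton's compilation cache key, so runs with the flag on and off do not share cached kernel binaries. A minimal, hedged usage sketch:

import os

# Set the flag before any kernel is compiled (the updated test below exports it at
# import time); since it is listed in CACHE_INVALIDATING_ENV_VARS, toggling it should
# force recompilation rather than reuse of a previously cached binary.
os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR"] = "1"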
python/test/unit/intel/test_block_store.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
 import triton
 from triton._internal_testing import is_xpu

+os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR"] = "1"
 os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS"] = "1"


@@ -188,5 +189,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)

-    if support_block_io and block_ptr:
+    if support_block_io:
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
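
A standalone sketch of what the relaxed assertion now covers: with the flag set, a store through regular (non-block) pointers is also expected to lower to a 2D block write. It assumes an Intel XPU device, illustrative kernel and size names, and the usual test-suite pattern of inspecting the generated LLVM IR through the handle returned by the kernel launch.

import os

os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR"] = "1"
os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS"] = "1"

import torch
import triton
import triton.language as tl

@triton.jit
def copy_2d(src, dst, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Regular (non-block) pointers: base + row * stride + col.
    offs = tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    tl.store(dst + offs, tl.load(src + offs))

a = torch.randn((64, 64), dtype=torch.float16, device="xpu")
b = torch.empty_like(a)
pgm = copy_2d[(1, 1, 1)](a, b, BLOCK_M=64, BLOCK_N=64)
assert torch.equal(a, b)
# Same check as the updated test, now expected to hold on the regular-pointer path too.
assert ('spirv_Subgroup2DBlockStoreINTEL' in pgm.asm['llir']
        or 'GenISA.LSC2DBlockWrite' in pgm.asm['llir'])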

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,6 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
     %c1_i64 = arith.constant 1 : i64
     %0 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>

-    // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[WARP_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32

     // CHECK: %[[offsetBaseY:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -206,6 +205,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
     // CHECK: %[[colStride:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[base:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>

+    // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[rowStride_i32:.*]] = llvm.trunc %[[rowStride]] : i64 to i32
     // CHECK: %[[PITCH:.*]] = llvm.mul %[[rowStride_i32]], %[[C2]]
     // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(f16, f16, {{.*}})>
@@ -263,14 +263,14 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
     // CHECK: %[[VAL_79:.*]] = llvm.insertvalue %[[rowStride]], %[[VAL_78]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[VAL_80:.*]] = llvm.insertvalue %[[CST_1]], %[[VAL_79]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[BLOCK_PTR:.*]] = llvm.insertvalue %[[base]], %[[VAL_80]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-    // CHECK: %[[SCALAR_BYTES:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[OFF_0:.*]] = llvm.extractvalue %[[BLOCK_PTR]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[OFF_1:.*]] = llvm.extractvalue %[[BLOCK_PTR]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_PTR]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_PTR]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_PTR]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_PTR]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
     // CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %[[BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // CHECK: %[[SCALAR_BYTES:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[WIDTH:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
     // CHECK: %[[ROW_STRIDE:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
     // CHECK: %[[WIDTH_IN_BYTES:.*]] = llvm.mul %[[WIDTH]], %[[SCALAR_BYTES]] : i32
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_STORE_ON_REGULAR_PTR=1 TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=ALL-LAYOUT
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4, 2], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 8, 2], order = [2, 1, 0]}>
+#slice = #ttg.slice<{dim = 1, parent = #blocked}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<i8>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #slice}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #slice}>> -> tensor<256x1xi32, #slice>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #slice>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #slice>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #slice}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #slice}>> -> tensor<1x64xi32, #slice>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #slice> -> tensor<256x64xi32, #slice>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #slice> -> tensor<256x64xi32, #slice>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #slice>
+    %9 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<256x64x!tt.ptr<i8>, #slice>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<i8>, #slice>, tensor<256x64xi32, #slice>
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #slice>
+    // ALL-LAYOUT-COUNT-32: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<i8>, #slice>
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 2], threadsPerWarp = [1, 32], warpsPerCTA = [8, 2], order = [1, 0]}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<i8>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #blocked>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<256x64x!tt.ptr<i8>, #blocked>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<i8>, #blocked>, tensor<256x64xi32, #blocked>
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #blocked>
+    // ALL-LAYOUT-COUNT-8: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 4, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<i8>, #blocked>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<i8>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>> -> tensor<256x1xi32, #dot_a>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #dot_a>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #dot_a>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_a}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_a}>> -> tensor<1x64xi32, #dot_a>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #dot_a> -> tensor<256x64xi32, #dot_a>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #dot_a> -> tensor<256x64xi32, #dot_a>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #dot_a>
+    %9 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<256x64x!tt.ptr<i8>, #dot_a>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<i8>, #dot_a>, tensor<256x64xi32, #dot_a>
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #dot_a>
+    // CHECK-COUNT-32: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<i8>, #dot_a>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>> -> tensor<256x1xi32, #dot_a>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #dot_a>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #dot_a>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_a}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_a}>> -> tensor<1x64xi32, #dot_a>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #dot_a> -> tensor<256x64xi32, #dot_a>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #dot_a> -> tensor<256x64xi32, #dot_a>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #dot_a>
+    %9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dot_a>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dot_a>, tensor<256x64xi32, #dot_a>
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dot_a>
+    // CHECK-COUNT-128: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dot_a>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+#dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 1}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_b}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_b}>> -> tensor<256x1xi32, #dot_b>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #dot_b>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #dot_b>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_b}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dot_b}>> -> tensor<1x64xi32, #dot_b>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #dot_b> -> tensor<256x64xi32, #dot_b>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #dot_b> -> tensor<256x64xi32, #dot_b>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #dot_b>
+    %9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dot_b>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dot_b>, tensor<256x64xi32, #dot_b>
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dot_b>
+    // CHECK-COUNT-128: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dot_b>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
+  // CHECK-LABEL: @regular_pointer_block_io
+  tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dpas}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dpas}>> -> tensor<256x1xi32, #dpas>
+    %2 = arith.constant dense<64> : tensor<256x1xi32, #dpas>
+    %3 = arith.muli %1, %2 : tensor<256x1xi32, #dpas>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dpas}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #dpas}>> -> tensor<1x64xi32, #dpas>
+    %6 = tt.broadcast %3 : tensor<256x1xi32, #dpas> -> tensor<256x64xi32, #dpas>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #dpas> -> tensor<256x64xi32, #dpas>
+    %8 = arith.addi %6, %7 : tensor<256x64xi32, #dpas>
+    %9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dpas>
+    %addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dpas>, tensor<256x64xi32, #dpas>
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dpas>
+    // CHECK-COUNT-32: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dpas>
+
+    tt.return
+  }
+}
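
For orientation, a hedged Python-level sketch of the kind of kernel whose Triton IR matches the make_range / expand_dims / broadcast / addptr / store sequence in the cases above; stores that the Intel backend can annotate as row-major block accesses (the ttig.block_io attribute seen above) are expected, with the flag from the RUN lines, to lower to triton_gen.2Dblockstore. The kernel name and block sizes are illustrative.

import triton
import triton.language as tl

@triton.jit
def store_zeros_2d(out_ptr, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    rows = tl.arange(0, BLOCK_M)[:, None]    # tt.make_range + tt.expand_dims (axis = 1)
    cols = tl.arange(0, BLOCK_N)[None, :]    # tt.make_range + tt.expand_dims (axis = 0)
    ptrs = out_ptr + rows * stride_m + cols  # arith.muli/addi + tt.broadcast + tt.addptr
    vals = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    tl.store(ptrs, vals)                     # tt.store through regular pointers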
