Commit fa40e4c

[LoadStoreOpToLLVM] Improve the getBlockIOTileSize to support DotOp layout (#4704)
Improve getBlockIOTileSize to support the DotOp layout by stacking the names in incremental order instead of in sequential, contiguous order. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent: 877cf7a
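
The new behavior is opt-in via the TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS environment variable added below. A minimal usage sketch of how one might exercise it, mirroring the updated test; the kernel, shapes, and launch details here are illustrative placeholders, not part of this commit:

import os

# The flag is presumably consulted while the kernel is lowered (it is also added to the
# cache-invalidating set in GetEnv.hpp below), so set it before anything is compiled.
os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS"] = "1"

import triton
import triton.language as tl


@triton.jit
def copy_row(src_ptr, dst_ptr, N: tl.constexpr):
    # Trivial placeholder kernel: load a contiguous row and store it back.
    offs = tl.arange(0, N)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))

# After compiling and launching on an XPU device, the generated LLVM IR can be checked
# for the 2D block-store intrinsics exactly as the updated test does:
#   'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir']
#   or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
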

File tree

4 files changed, +398 -118 lines


include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
+    "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",

python/test/unit/intel/test_block_store.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import itertools
 
 import numpy as np
@@ -8,6 +9,8 @@
 import triton
 from triton._internal_testing import is_xpu
 
+os.environ["TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS"] = "1"
+
 
 class DpasLayout:
 
@@ -184,3 +187,6 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)
+
+    if support_block_io and block_ptr:
+        assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 147 additions & 2 deletions
@@ -1,4 +1,148 @@
 // RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: env TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS=1 triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=ALL-LAYOUT
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @dot_a_layout
+  tt.func public @dot_a_layout(%arg0: !tt.ptr<i8>, %col_stride: i64) {
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #dot_a>
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x64xi8, #dot_a>>
+    // ALL-LAYOUT: %[[OFF_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[OFF_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[HEIGHT_i64:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[WIDTH_i64:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[BASE_PTR:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+
+    // ALL-LAYOUT: %[[HEIGHT:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
+
+    // ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
+    // ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
+    // ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
+    // ALL-LAYOUT: llvm.mlir.undef : vector<4xi8>
+    // ALL-LAYOUT-COUNT-4: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<4xi8>
+    // ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 8, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #dot_a>>
+    // ALL-LAYOUT-COUNT-63: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 8, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
+#dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 1}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @dot_b_layout
+  tt.func public @dot_b_layout(%arg0: !tt.ptr<i8>, %col_stride: i64) {
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #dot_b>
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x64xi8, #dot_b>>
+    // ALL-LAYOUT: %[[OFF_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[OFF_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[HEIGHT_i64:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[WIDTH_i64:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[BASE_PTR:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+
+    // ALL-LAYOUT: %[[HEIGHT:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
+
+    // ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
+    // ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
+    // ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
+    // ALL-LAYOUT: llvm.mlir.undef : vector<8xi8>
+    // ALL-LAYOUT-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xi8>
+    // ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 8, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #dot_b>>
+    // ALL-LAYOUT-COUNT-63: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 8, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4, 2], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 8, 2], order = [2, 1, 0]}>
+#slice = #ttg.slice<{dim = 1, parent = #blocked}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @slice_layout
+  tt.func public @slice_layout(%arg0: !tt.ptr<i8>, %col_stride: i64) {
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #slice>
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x64xi8, #slice>>
+    // ALL-LAYOUT: %[[OFF_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[OFF_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[HEIGHT_i64:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[WIDTH_i64:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[BASE_PTR:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+
+    // ALL-LAYOUT: %[[HEIGHT:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
+
+    // ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
+    // ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
+    // ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
+    // ALL-LAYOUT: llvm.mlir.undef : vector<16xi8>
+    // ALL-LAYOUT-COUNT-16: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<16xi8>
+    // ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 8, v_blocks = 1, cache_control = Default}
+    tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #slice>>
+    // ALL-LAYOUT-COUNT-31: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 8, v_blocks = 1, cache_control = Default}
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 2], threadsPerWarp = [1, 32], warpsPerCTA = [8, 2], order = [1, 0]}>
+module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @block_layout
+  tt.func public @block_layout(%arg0: !tt.ptr<i8>, %col_stride: i64) {
+    %cst = arith.constant dense<0> : tensor<256x64xi8, #blocked>
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x64xi8, #blocked>>
+    // ALL-LAYOUT: %[[OFF_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[OFF_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[HEIGHT_i64:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[WIDTH_i64:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+    // ALL-LAYOUT: %[[BASE_PTR:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+
+    // ALL-LAYOUT: %[[HEIGHT:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
+
+    // ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
+    // ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
+    // ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
+    // ALL-LAYOUT: llvm.mlir.undef : vector<8xi8>
+    // ALL-LAYOUT-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xi8>
+    // ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 4, v_blocks = 1, cache_control = Default}
+    tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #blocked>>
+    // ALL-LAYOUT-COUNT-7: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 4, v_blocks = 1, cache_control = Default}
+
+    tt.return
+  }
+}
+
+// -----
 
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
@@ -72,7 +216,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
 
     // COM: When boundary check is absent:
     // CHECK: %[[baseWidth:.*]] = llvm.mlir.constant(64 : i32)
-    // CHECK: %[[base1:.*]] = llvm.getelementptr %[[base]][%[[OFFSET_X]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i16
+    // CHECK: %[[base1:.*]] = llvm.getelementptr %[[base]][%[[OFFSET_X]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f16
     // CHECK: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[baseHeight:.*]] = llvm.mlir.constant(8 : i32)
     // CHECK: %[[OFF:.*]] = llvm.mul %[[OFFSET_Y]], %[[PITCH]] : i32
@@ -141,7 +285,8 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
     // COM: [6, 7]
 
     // COM: replica [0, 0]
-    // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
+    // CHECK: llvm.call spir_funccc @_Z12get_local_idj
+    // CHECK-COUNT-4: llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[VAL_186:.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_186]] : i32
     // CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_186]] : i32
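
The counts in the new ALL-LAYOUT checks line up with a simple tile count. Taking @dot_a_layout as an example: the one explicitly matched triton_gen.2Dblockstore plus ALL-LAYOUT-COUNT-63 give 64 stores per sub-group, which is consistent with each sub-group owning a 64x64 slice of the 256x64 i8 tensor (assuming the A operand is replicated across the four warps in the N dimension) and each 8x8, 8-bit block store covering 64 elements. A small sanity-check sketch of that arithmetic; the operand-distribution assumption is ours, not stated in the diff:

# Rough sanity check of the dot_a FileCheck counts.
M, K = 256, 64                        # tensor<256x64xi8, #dot_a>
warps_m = 4                           # warpsPerCTA = [4, 4]; the N-dimension warps replicate operand A (assumption)
elems_per_warp = (M // warps_m) * K   # 64 * 64 = 4096 i8 elements per sub-group
tile_w, tile_h = 8, 8                 # tile_width = 8, tile_height = 8, elem_size_in_bits = 8
stores_per_warp = elems_per_warp // (tile_w * tile_h)
print(stores_per_warp)                # 64 = 1 explicitly checked store + ALL-LAYOUT-COUNT-63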
