
Commit 16f5738

[XPU][Alloc] Optimize SLM allocation size for sub-group layout conversions (#2638)
Optimize shared memory allocation sizes for sub-group shuffle and transpose-like conversions:

- Sub-group shuffle: do not allocate any shared memory.
- Sub-group transpose: allocate exactly as much memory as is needed to store the whole tensor in SLM.

Signed-off-by: victor-eds <[email protected]>
1 parent 98dca47 commit 16f5738
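
For reference, a back-of-the-envelope check of what the transpose rule implies for the sizes asserted in the new lit test below (a sketch, assuming the scratch buffer holds exactly one copy of the converted tensor, i.e. element count times element size):

// Sanity-check sketch (not part of the commit): whole-tensor SLM sizes
// expected by the new allocation test cases.
static_assert(16 * 16 * 2 == 512, "16x16 f16  -> triton_gpu.shared = 512");
static_assert(16 * 16 * 4 == 1024, "16x16 f32  -> triton_gpu.shared = 1024");
static_assert(128 * 64 * 4 == 32768, "128x64 f32 -> triton_gpu.shared = 32768");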

File tree

3 files changed (+93, -6 lines changed)
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 0 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK: tt.func @test_sub_group_shuffle
  // CHECK-NOT: llvm.ptr<3>
  tt.func @test_sub_group_shuffle(%arg0: tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    tt.return %0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 512 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
    tt.return %0 : tensor<16x16xf16, #blocked1>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 1024 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
    tt.return %0 : tensor<16x16xf32, #blocked1>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 32768 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f32(%arg0: tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1>
    tt.return %0 : tensor<128x64xf32, #blocked1>
  }
}
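
For intuition on why the transpose-like cases above need scratch space for the whole tensor while the shuffle-like case needs none: in a transpose each lane must end up with values that currently sit in other lanes' registers, so the data round-trips through a shared tile written in one layout and read back in the other. The following host-side C++ sketch is purely illustrative (it is not the code the backend generates; the 16x16 tile size and the names are assumptions taken from the tests above):

#include <array>
#include <cstddef>

// Illustrative sketch: lane `l` starts with a contiguous column of 16 values
// (source layout) and must end up with a row of 16 values (destination layout).
constexpr std::size_t kSubGroupSize = 16;
using LaneData = std::array<float, kSubGroupSize>;

void subGroupTransposeSketch(std::array<LaneData, kSubGroupSize> &lanes) {
  // Stands in for the SLM tile: one slot per tensor element, i.e. the whole
  // 16x16 tile must fit, which is why the allocation covers the full tensor.
  std::array<LaneData, kSubGroupSize> scratch{};
  for (std::size_t l = 0; l < kSubGroupSize; ++l) // store phase
    for (std::size_t i = 0; i < kSubGroupSize; ++i)
      scratch[l][i] = lanes[l][i];
  // (a barrier would separate the two phases on the device)
  for (std::size_t l = 0; l < kSubGroupSize; ++l) // gather phase
    for (std::size_t i = 0; i < kSubGroupSize; ++i)
      lanes[l][i] = scratch[i][l];
}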

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,7 @@

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>)
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16)>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_4]])
@@ -49,7 +49,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
}

// CHECK-LABEL: llvm.func spir_kernelcc @test_bf16(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>)
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(bf16)>
// CHECK: %[[VAL_2:.*]] = llvm.bitcast %[[VAL_1]] : bf16 to i16
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -91,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
}

// CHECK-LABEL: llvm.func spir_kernelcc @test_i1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>)
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i1)>
// CHECK: %[[VAL_2:.*]] = llvm.zext %[[VAL_1]] : i1 to i8
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -133,7 +133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
}

// CHECK-LABEL: llvm.func spir_kernelcc @test_ptr(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>)
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(ptr<1>)>
// CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr<1> to i64
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>)
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32)>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_4]])
@@ -269,7 +269,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>)
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32

third_party/intel/lib/Analysis/Allocation.cpp

Lines changed: 22 additions & 0 deletions
@@ -15,6 +15,8 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"

+#include "intel/include/Analysis/Utility.h"
+
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -104,6 +106,26 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {

ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                     RankedTensorType dstTy) {
+  if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) {
+    // Conversions that can be implemented as sub-group shuffles do not need
+    // scratch memory.
+    return ScratchConfig({}, {});
+  }
+
+  if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) {
+    // Conversions that can be implemented as sub-group transposes store the
+    // whole tensor in shared memory and read it afterwards.
+    auto srcEncoding = cast<gpu::DistributedEncodingTrait>(srcTy.getEncoding());
+    unsigned threadsPerWarp = product(srcEncoding.getThreadsPerWarp());
+    unsigned warpsPerCTA = product(srcEncoding.getWarpsPerCTA());
+    unsigned remaining = product(srcTy.getShape()) /
+                         (threadsPerWarp * threadsPerWarp * warpsPerCTA);
+    SmallVector<unsigned> repShape{threadsPerWarp, threadsPerWarp, remaining,
+                                   warpsPerCTA};
+    return ScratchConfig(repShape, repShape,
+                         /*inVec=*/1, /*outVec=*/threadsPerWarp);
+  }
+
  // Initialize vector sizes and stride
  auto repShape = getRepShapeForCvt(srcTy, dstTy);
  if (repShape.empty())
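
As a quick check that the repShape chosen in the transpose branch covers the whole tensor, here is a small standalone sketch (not Triton code; it only reproduces the arithmetic above for the 128x64xf32, 8-warp test case and assumes the resulting scratch size is product(repShape) times the element size):

#include <cassert>

int main() {
  // Values taken from the 128x64xf32 test case with 8 warps of 16 threads.
  unsigned threadsPerWarp = 16, warpsPerCTA = 8;
  unsigned numElems = 128u * 64u; // product(srcTy.getShape())
  unsigned remaining =
      numElems / (threadsPerWarp * threadsPerWarp * warpsPerCTA); // == 4
  unsigned repShape[4] = {threadsPerWarp, threadsPerWarp, remaining,
                          warpsPerCTA};
  unsigned repElems = repShape[0] * repShape[1] * repShape[2] * repShape[3];
  assert(repElems == numElems);                    // spans the whole tensor
  assert(repElems * 4 /*bytes per f32*/ == 32768); // triton_gpu.shared = 32768
  return 0;
}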

0 commit comments
