Skip to content

Commit 290bfa9

Browse files
authored
[XPU][TritonIntelGPUToLLVM] Make sub-group CVT checks public (#2629)
Expose `cvtIsSubGroupShuffle` and `cvtIsSubGroupTranspose` as API functions to be able to reuse them in an Intel-specific memory allocation analysis. Signed-off-by: victor-eds <[email protected]>
1 parent 0c51766 commit 290bfa9

File tree

3 files changed

+217
-190
lines changed

3 files changed

+217
-190
lines changed

third_party/intel/include/Analysis/Utility.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ namespace mlir::triton::gpu::intel {
77

88
bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy);
99

10+
/// Return whether the layout conversion from `srcTy` to `dstTy` can be
11+
/// performed as a sub-group shuffle.
12+
bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
13+
/// Return whether the layout conversion from `srcTy` to `dstTy` can be
14+
/// performed as a sub-group transpose through local memory.
15+
bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy);
16+
/// Return whether `type` is a valid element type for a fast sub-group
17+
/// transpose.
18+
bool isValidElementTypeForSubGroupTranspose(Type type);
19+
1020
} // namespace mlir::triton::gpu::intel
1121

1222
#endif // TRITON_INTEL_ANALYSIS_UTILITY_H

third_party/intel/lib/Analysis/Utility.cpp

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,91 @@
1-
#include "triton/Analysis/Utility.h"
21
#include "intel/include/Analysis/Utility.h"
2+
3+
#include "llvm/ADT/TypeSwitch.h"
4+
5+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
6+
37
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
48

59
namespace mlir::triton::gpu::intel {
10+
namespace {
11+
constexpr inline unsigned minSubGroupTransposeWidth = 8;
12+
13+
// Return whether values of element type `type` can be handled by the
// sub-group transpose lowering, possibly after a conversion: floats are
// bitcast to same-width integers, narrow integers are extended to the
// minimum supported width, and pointers go through ptrtoint.
bool canTypeBeConvertedForSubGroupTranspose(Type type) {
  if (auto floatTy = dyn_cast<FloatType>(type)) {
    // Supported via bitcasting to an integer type of the same width.
    return isValidElementTypeForSubGroupTranspose(
        IntegerType::get(floatTy.getContext(), floatTy.getWidth()));
  }
  if (auto intTy = dyn_cast<IntegerType>(type)) {
    // Either directly supported or extendable to a supported width.
    return isValidElementTypeForSubGroupTranspose(intTy) ||
           intTy.getWidth() < minSubGroupTransposeWidth;
  }
  // Pointers are supported via ptrtoint.
  return isa<PointerType>(type);
}
31+
32+
// Build the register bases expected for a sub-group transpose:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
// [registerSize / 2, 0]],
// i.e., registers map to lanes up to laneSize and act as an identity on
// the register dimension from there on.
std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Powers of two below laneSize target the lane output dimension.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    bases.push_back({0, stride});
  // Remaining powers of two stay on the register output dimension.
  for (int32_t stride = laneSize; stride < registerSize; stride *= 2)
    bases.push_back({stride, 0});
  return bases;
}
52+
53+
// Build the register bases expected for a sub-group shuffle:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
// [registerSize / (2 * laneSize), 0]],
// i.e., registers map to lanes up to laneSize and the pattern repeats
// afterwards.
std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Powers of two below laneSize target the lane output dimension.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    bases.push_back({0, stride});
  // Past laneSize, the register output coordinate doubles each step.
  for (int32_t reg = 1; reg * laneSize < registerSize; reg *= 2)
    bases.push_back({reg, 0});
  return bases;
}
73+
74+
// Build the lane bases expected for a sub-group transpose:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
// i.e., lanes map to registers.
std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    bases.push_back({stride, 0});
  return bases;
}
87+
88+
} // namespace
689

790
bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy) {
891
auto dpasLayout = dyn_cast<DpasEncodingAttr>(dpasTy.getEncoding());
@@ -24,4 +107,120 @@ bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy) {
24107
return false;
25108
}
26109

110+
bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
  MLIRContext *ctx = srcTy.getContext();
  StringAttr kRegister = str_attr("register");
  StringAttr kLane = str_attr("lane");
  StringAttr kWarp = str_attr("warp");
  StringAttr kBlock = str_attr("block");

  // Both layouts must be expressible as linear layouts to compare them.
  std::optional<LinearLayout> srcLayout =
      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
  if (!srcLayout)
    return false;

  std::optional<LinearLayout> dstLayout =
      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
  if (!dstLayout)
    return false;

  // Factor out the block and warp dimensions: a sub-group shuffle cannot move
  // data across warps or blocks, so both must act trivially in the
  // conversion.
  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
  if (!conversion)
    return false;
  conversion = conversion->quotient(kWarp);
  if (!conversion)
    return false;

  // TODO: Support more kinds of shuffles.
  // Expected conversion is:
  // - register=1 -> (0, 1)
  //   ...
  // - register=2**i -> (0, 2**i)
  //   ...
  // - register=2**M -> (0, 2**M)
  //   ...
  // - register=2**k -> (2**(k-M), 0)
  //   ...
  // - register=2**N -> (2**(N-M), 0)
  // - lane=1 -> (0, 0)
  //   ...
  // - lane=2**j -> (0, 0)
  //   ...
  // - lane=2**M -> (0, 0)
  // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
  //
  // With N >= M.
  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
  int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
  return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
         conversion->getBases().lookup(kRegister) ==
             buildSubGroupShuffleRegisterBases(registerInDimSize,
                                               laneOutDimSize);
}
161+
162+
/// The fast sub-group transpose path only handles integer element types of
/// width 8, 16, 32 or 64; any other type must be converted first (see
/// `canTypeBeConvertedForSubGroupTranspose`).
bool isValidElementTypeForSubGroupTranspose(Type type) {
  auto intTy = dyn_cast<IntegerType>(type);
  if (!intTy)
    return false;
  switch (intTy.getWidth()) {
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
170+
171+
bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
  // The lowering needs to bitcast/extend/ptrtoint elements to a supported
  // integer width; bail out early if no such conversion exists.
  if (!canTypeBeConvertedForSubGroupTranspose(srcTy.getElementType()))
    return false;

  MLIRContext *ctx = srcTy.getContext();
  StringAttr kRegister = str_attr("register");
  StringAttr kLane = str_attr("lane");
  StringAttr kWarp = str_attr("warp");
  StringAttr kBlock = str_attr("block");

  // Both layouts must be expressible as linear layouts to compare them.
  std::optional<LinearLayout> srcLayout =
      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
  if (!srcLayout)
    return false;

  std::optional<LinearLayout> dstLayout =
      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
  if (!dstLayout)
    return false;

  // Factor out the block and warp dimensions: the transpose happens within a
  // sub-group, so both must act trivially in the conversion.
  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
  if (!conversion)
    return false;
  conversion = conversion->quotient(kWarp);
  if (!conversion)
    return false;

  // Expected conversion is:
  // - register=1 -> (0, 1)
  //   ...
  // - register=2**i -> (0, 2**i)
  //   ...
  // - register=2**M -> (0, 2**M)
  //   ...
  // - register=2**k -> (2**k, 0)
  //   ...
  // - register=2**N -> (2**N, 0)
  // - lane=1 -> (1, 0)
  //   ...
  // - lane=2**j -> (2**j, 0)
  //   ...
  // - lane=2**M -> (2**M, 0)
  // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
  //
  // With N >= M.
  //
  // NOTE(review): the lane entries above were documented as `(0, 1)` for
  // lane=1, but `buildSubGroupTransposeLaneBases` builds
  // [[1, 0], [2, 0], ...], so lane=2**j maps to (2**j, 0); fixed here.
  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
  int32_t laneInDimSize = conversion->getInDimSize(kLane);
  return conversion->getBases().lookup(kRegister) ==
             buildSubGroupTransposeRegisterBases(registerInDimSize,
                                                 laneInDimSize) &&
         conversion->getBases().lookup(kLane) ==
             buildSubGroupTransposeLaneBases(laneInDimSize);
}
225+
27226
} // namespace mlir::triton::gpu::intel

0 commit comments

Comments (0)