Commit 6c5b353

Drop barriers
1 parent 966d6f1 commit 6c5b353

File tree

5 files changed: +248 -31 lines changed


third_party/intel/include/Analysis/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -10,9 +10,13 @@ bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy);
 /// Return whether the layout conversion from `srcTy` to `dstTy` can be
 /// performed as a sub-group shuffle.
 bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
+bool cvtIsContiguousSubGroupShuffle(RankedTensorType srcTy,
+                                    RankedTensorType dstTy);
 /// Return whether the layout conversion from `srcTy` to `dstTy` can be
 /// performed as a sub-group transpose through local memory.
 bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy);
+bool cvtIsContiguousSubGroupTranspose(RankedTensorType srcTy,
+                                      RankedTensorType dstTy);
 /// Return whether `type` is a valid element type for a fast sub-group
 /// transpose.
 bool isValidElementTypeForSubGroupTranspose(Type type);

third_party/intel/lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 2 deletions
@@ -106,7 +106,8 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 
 ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                      RankedTensorType dstTy) {
-  if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) {
+  if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy) ||
+      gpu::intel::cvtIsContiguousSubGroupShuffle(srcTy, dstTy)) {
     // Conversions that can be implemented as sub-group shuffles do not need
     // scratch memory.
     return ScratchConfig({}, {});
@@ -117,7 +118,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
     return ScratchConfig({}, {});
   }
 
-  if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) {
+  if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy) ||
+      gpu::intel::cvtIsContiguousSubGroupTranspose(srcTy, dstTy)) {
     // Conversions that can be implemented as sub-group transposes store the
     // whole tensor in shared memory and read it afterwards.
     auto srcEncoding = cast<gpu::DistributedEncodingTrait>(srcTy.getEncoding());

third_party/intel/lib/Analysis/Membar.cpp

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@ triton::gpu::ConvertLayoutOp dynCastToSubGroupTranspose(Operation *op) {
     return nullptr;
 
   if (!triton::gpu::intel::cvtIsSubGroupTranspose(
+          convertLayout.getSrc().getType(),
+          convertLayout.getResult().getType()) &&
+      !triton::gpu::intel::cvtIsContiguousSubGroupTranspose(
           convertLayout.getSrc().getType(),
           convertLayout.getResult().getType()))
     return nullptr;

third_party/intel/lib/Analysis/Utility.cpp

Lines changed: 173 additions & 3 deletions
@@ -71,6 +71,28 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
   return bases;
 }
 
+// Return a vector such as:
+// [[1, 0], ..., [registerSize / (laneSize * 2), 0], [0, 1], [0, 2], [0, 4],
+// ..., [0, laneSize / 2]], i.e., mapping registers to registers up to
+// registerSize / laneSize and mapping the remaining registers to lanes.
+std::vector<std::vector<int32_t>>
+buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
+                                            int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  int i = 1;
+  for (; i < registerSize / laneSize; i *= 2) {
+    curr[0] = i;
+    bases.push_back(curr);
+  }
+  curr[0] = 0;
+  for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
+    curr[1] = val;
+    bases.push_back(curr);
+  }
+  return bases;
+}
+
 // Return a vector such as:
 // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
 // i.e., mapping lanes to registers.
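
To make the basis pattern above concrete, here is a small standalone sketch (an editor's illustration, not part of the commit) that runs the same two loops as buildContiguousSubGroupShuffleRegisterBases outside of MLIR, for a hypothetical register in-dim size of 32 and lane size of 8; both sizes are assumptions chosen only to make the output visible.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the loop structure of buildContiguousSubGroupShuffleRegisterBases
// from the diff above. Out dims are (register, lane).
std::vector<std::vector<int32_t>> buildBases(int32_t registerSize,
                                             int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  int i = 1;
  // Low register bits stay registers, i.e., remain within the same lane.
  for (; i < registerSize / laneSize; i *= 2) {
    curr[0] = i;
    bases.push_back(curr);
  }
  curr[0] = 0;
  // The remaining register bits map onto lane offsets.
  for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
    curr[1] = val;
    bases.push_back(curr);
  }
  return bases;
}

int main() {
  // Assumed sizes for illustration only: registerSize = 32, laneSize = 8.
  for (const auto &b : buildBases(32, 8))
    std::cout << "[" << b[0] << ", " << b[1] << "] ";
  std::cout << "\n"; // Prints: [1, 0] [2, 0] [0, 1] [0, 2] [0, 4]
  return 0;
}

In this instance the first two bases map register increments to register increments (staying within a lane), and the remaining bases map register increments to lane offsets.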
@@ -85,6 +107,46 @@ buildSubGroupTransposeLaneBases(int32_t laneSize) {
   return bases;
 }
 
+// Return a vector such as:
+// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
+// [registerSize / (laneSize * 2), 0]],
+// i.e., mapping registers to lanes till laneSize and performing an ID
+// conversion afterwards.
+std::vector<std::vector<int32_t>>
+buildContiguousSubGroupTransposeRegisterBases(int32_t registerSize,
+                                              int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  int i = 1;
+  for (; i < laneSize; i *= 2) {
+    curr[1] = i;
+    bases.push_back(curr);
+  }
+  curr[1] = 0;
+  for (int32_t j = 1; i < registerSize; i *= 2, j *= 2) {
+    curr[0] = j;
+    bases.push_back(curr);
+  }
+  return bases;
+}
+
+// Return a vector such as:
+// [[registerSize / laneSize, 0], [registerSize / laneSize * 2, 0], ...,
+// [registerSize / 2, 0]],
+// i.e., mapping lanes to registers starting at
+// registerSize / laneSize.
+std::vector<std::vector<int32_t>>
+buildContiguousSubGroupTransposeLaneBases(int32_t registerSize,
+                                          int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  for (int32_t i = registerSize / laneSize; i < registerSize; i *= 2) {
+    curr[0] = i;
+    bases.push_back(curr);
+  }
+  return bases;
+}
+
 } // namespace
 
 bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy) {
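
As a companion illustration (again an editor's sketch, not part of the commit), the following snippet evaluates the two transpose base builders for the same assumed sizes, a register in-dim of 32 and a lane size of 8.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors buildContiguousSubGroupTransposeRegisterBases from the diff above.
// Out dims are (register, lane).
std::vector<std::vector<int32_t>> registerBases(int32_t registerSize,
                                                int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  int i = 1;
  // Low register bits select the lane.
  for (; i < laneSize; i *= 2) {
    curr[1] = i;
    bases.push_back(curr);
  }
  curr[1] = 0;
  // The remaining register bits stay registers.
  for (int32_t j = 1; i < registerSize; i *= 2, j *= 2) {
    curr[0] = j;
    bases.push_back(curr);
  }
  return bases;
}

// Mirrors buildContiguousSubGroupTransposeLaneBases from the diff above:
// every lane bit maps to a register offset, starting at registerSize / laneSize.
std::vector<std::vector<int32_t>> laneBases(int32_t registerSize,
                                            int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  for (int32_t i = registerSize / laneSize; i < registerSize; i *= 2) {
    curr[0] = i;
    bases.push_back(curr);
  }
  return bases;
}

static void dump(const std::vector<std::vector<int32_t>> &bases) {
  for (const auto &b : bases)
    std::cout << "[" << b[0] << ", " << b[1] << "] ";
  std::cout << "\n";
}

int main() {
  // Assumed sizes for illustration only: registerSize = 32, laneSize = 8.
  dump(registerBases(32, 8)); // Prints: [0, 1] [0, 2] [0, 4] [1, 0] [2, 0]
  dump(laneBases(32, 8));     // Prints: [4, 0] [8, 0] [16, 0]
  return 0;
}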
@@ -159,6 +221,59 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
                                               laneOutDimSize);
 }
 
+bool cvtIsContiguousSubGroupShuffle(RankedTensorType srcTy,
+                                    RankedTensorType dstTy) {
+  MLIRContext *ctx = srcTy.getContext();
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  StringAttr kBlock = str_attr("block");
+
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  if (!srcLayout)
+    return false;
+
+  std::optional<LinearLayout> dstLayout =
+      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  if (!dstLayout)
+    return false;
+
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
+  if (!conversion)
+    return false;
+  conversion = conversion->quotient(kWarp);
+  if (!conversion)
+    return false;
+
+  // TODO: Support more kinds of shuffles.
+  // Expected conversion is:
+  // - register=1 -> (0, 1)
+  // ...
+  // - register=2**i -> (0, 2**i)
+  // ...
+  // - register=M -> (0, 2**M)
+  // ...
+  // - register=2**k -> (2**(k-M), 0)
+  // ...
+  // - register=2**N -> (2**(N-M), 0)
+  // - lane=1 -> (0, 0)
+  // ...
+  // - lane=2**j -> (0, 0)
+  // ...
+  // lane=2**M -> (0, 0)
+  // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
+  //
+  // With N >= M.
+  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+  int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
+  return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
+         conversion->getBases().lookup(kRegister) ==
+             buildContiguousSubGroupShuffleRegisterBases(registerInDimSize,
+                                                         laneOutDimSize);
+}
+
 bool isValidElementTypeForSubGroupTranspose(Type type) {
   return TypeSwitch<Type, bool>(type)
       .Case([](IntegerType intTy) {
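
Putting the pieces together, the conversion that the final comparison in cvtIsContiguousSubGroupShuffle accepts combines the register bases from the builder shown earlier with an all-zero lane sublayout, which is what sublayoutIsZero({kLane}, {kRegister, kLane}) enforces. The sketch below (an editor's illustration with assumed sizes, register in-dim 32 and lane dim 8) enumerates that accepted mapping.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical conversion accepted by cvtIsContiguousSubGroupShuffle,
  // assuming a register in-dim of size 32 and a lane dim of size 8 (values
  // chosen only for this illustration). Outputs are (register, lane).
  std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>> expected = {
      // Register bases: buildContiguousSubGroupShuffleRegisterBases(32, 8).
      {"register=1", {1, 0}},
      {"register=2", {2, 0}},
      {"register=4", {0, 1}},
      {"register=8", {0, 2}},
      {"register=16", {0, 4}},
      // Lane bases: all zero, per sublayoutIsZero({kLane}, {kRegister, kLane}).
      {"lane=1", {0, 0}},
      {"lane=2", {0, 0}},
      {"lane=4", {0, 0}},
  };
  for (const auto &[in, out] : expected)
    std::cout << in << " -> (" << out.first << ", " << out.second << ")\n";
  return 0;
}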
@@ -196,16 +311,14 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
   if (!conversion)
     return false;
 
-  llvm::errs() << conversion << "\n";
-
   // Expected conversion is:
   // - register=1 -> (0, 1)
   // ...
   // - register=2**i -> (0, 2**i)
   // ...
   // - register=M -> (0, 2**M)
   // ...
-  // - register=2**k -> (2**k, 0)
+  // - register=2**k -> (, 0)
   // ...
   // - register=N -> (2**N, 0)
   // - lane=1 -> (0, 1)
@@ -225,6 +338,63 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
             buildSubGroupTransposeLaneBases(laneInDimSize);
 }
 
+bool cvtIsContiguousSubGroupTranspose(RankedTensorType srcTy,
+                                      RankedTensorType dstTy) {
+  if (!canTypeBeConvertedForSubGroupTranspose(srcTy.getElementType()))
+    return false;
+
+  MLIRContext *ctx = srcTy.getContext();
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  StringAttr kBlock = str_attr("block");
+
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  if (!srcLayout)
+    return false;
+
+  std::optional<LinearLayout> dstLayout =
+      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  if (!dstLayout)
+    return false;
+
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
+  if (!conversion)
+    return false;
+  conversion = conversion->quotient(kWarp);
+  if (!conversion)
+    return false;
+
+  // Expected conversion is:
+  // - register=1 -> (0, 1)
+  // ...
+  // - register=2**i -> (0, 2**i)
+  // ...
+  // - register=M -> (0, 2**M)
+  // ...
+  // - register=2**k -> (1, 0)
+  // ...
+  // - register=N -> (2**(N-k), 0)
+  // - lane=1 -> (2**(N-k+1), 0)
+  // ...
+  // - lane=2**j -> (2**(N-k+j), 0)
+  // ...
+  // lane=2**M -> (2**(N-k+M), 0)
+  // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
+  //
+  // With N >= M.
+  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+  int32_t laneInDimSize = conversion->getInDimSize(kLane);
+  return conversion->getBases().lookup(kRegister) ==
+             buildContiguousSubGroupTransposeRegisterBases(registerInDimSize,
+                                                           laneInDimSize) &&
+         conversion->getBases().lookup(kLane) ==
+             buildContiguousSubGroupTransposeLaneBases(registerInDimSize,
+                                                       laneInDimSize);
+}
+
 bool cvtIsUnbroadcast(RankedTensorType srcTy, RankedTensorType dstTy) {
   MLIRContext *ctx = srcTy.getContext();
   StringAttr kRegister = str_attr("register");
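
For the transpose predicate, the accepted conversion pairs the register bases and lane bases from the two builders: register increments first move across lanes and then across registers, while lane increments move across registers. The sketch below (an editor's illustration with assumed sizes, register in-dim 32 and lane in-dim 8) enumerates the mapping that the final comparison in cvtIsContiguousSubGroupTranspose accepts.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical conversion accepted by cvtIsContiguousSubGroupTranspose,
  // assuming a register in-dim of size 32 and a lane in-dim of size 8 (values
  // chosen only for this illustration). Outputs are (register, lane).
  std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>> expected = {
      // Register bases: buildContiguousSubGroupTransposeRegisterBases(32, 8).
      {"register=1", {0, 1}},
      {"register=2", {0, 2}},
      {"register=4", {0, 4}},
      {"register=8", {1, 0}},
      {"register=16", {2, 0}},
      // Lane bases: buildContiguousSubGroupTransposeLaneBases(32, 8).
      {"lane=1", {4, 0}},
      {"lane=2", {8, 0}},
      {"lane=4", {16, 0}},
  };
  for (const auto &[in, out] : expected)
    std::cout << in << " -> (" << out.first << ", " << out.second << ")\n";
  return 0;
}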
