Skip to content

Commit 2a72ba2

Browse files
authored
[XPU][TritonIntelGPUToLLVM] Add support for more shuffle kinds (#2799)
Add support for layout conversion shuffles in which rows managed by a single thread are contiguous in the output matrix. Step 2/2 to #2749 --------- Signed-off-by: victor-eds <[email protected]>
1 parent 67ea90d commit 2a72ba2

File tree

4 files changed

+159
-18
lines changed

4 files changed

+159
-18
lines changed

test/Conversion/intel/intel-allocate-shared-memory.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
1818

1919
// -----
2020

21+
#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
22+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
23+
24+
// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.
25+
26+
// CHECK-LABEL: module attributes
27+
// CHECK-SAME: triton_gpu.shared = 0 : i32
28+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
29+
// CHECK: tt.func @test_sub_group_shuffle
30+
// CHECK-NOT: llvm.ptr<3>
31+
tt.func @test_sub_group_shuffle(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
32+
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
33+
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
34+
}
35+
}
36+
37+
// -----
38+
2139
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
2240
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
2341

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
360360
tt.return %0 : tensor<128xi32, #sliced1>
361361
}
362362
}
363+
364+
// -----
365+
366+
#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
367+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
368+
369+
// Case of more than one contiguous element per work-item.
370+
371+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
372+
// CHECK-LABEL: llvm.func spir_kernelcc @test_contiguous(
373+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>)
374+
tt.func @test_contiguous(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
375+
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16, f16)>
376+
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f16, f16)>
377+
// COM: Check the shuffles are "coalesced"
378+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
379+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
380+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
381+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
382+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
383+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
384+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
385+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
386+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
387+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
388+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
389+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
390+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
391+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
392+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
393+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
394+
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
395+
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
396+
}
397+
}

third_party/intel/lib/Analysis/Utility.cpp

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,29 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
7171
return bases;
7272
}
7373

74+
// Return a vector such as:
75+
// [[1, 0], [2, 0], [4, 0], ..., [registerSize / (2 * laneSize), 0], [0, 1], ...,
76+
// [0, laneSize/2]]
77+
// i.e., mapping registers to registers till registerSize / laneSize (all
78+
// contiguous registers) and then to lanes.
79+
std::vector<std::vector<int32_t>>
80+
buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
81+
int32_t laneSize) {
82+
std::vector<std::vector<int32_t>> bases;
83+
std::vector<int32_t> curr(2);
84+
int i = 1;
85+
for (; i < registerSize / laneSize; i *= 2) {
86+
curr[0] = i;
87+
bases.push_back(curr);
88+
}
89+
curr[0] = 0;
90+
for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
91+
curr[1] = val;
92+
bases.push_back(curr);
93+
}
94+
return bases;
95+
}
96+
7497
// Return a vector such as:
7598
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
7699
// i.e., mapping lanes to registers.
@@ -138,25 +161,49 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
138161
// ...
139162
// - register=2**i -> (0, 2**i)
140163
// ...
141-
// - register=M -> (0, 2**M)
164+
// - register=M -> (0, 2**(M-1))
165+
// - register=M+1 -> (1, 0)
142166
// ...
143-
// - register=2**k -> (2**(k-M), 0)
167+
// - register=2**k -> (2**(k-M), 0)
144168
// ...
145169
// - register=2**N -> (2**(N-M), 0)
146170
// - lane=1 -> (0, 0)
147171
// ...
148172
// - lane=2**j -> (0, 0)
149173
// ...
150174
// lane=2**M -> (0, 0)
175+
// where out dims are: [register (size 2**N), lane (size 2**M)]
176+
//
177+
// With N >= M.
178+
//
179+
// Or, when the elements managed by a given work-item are in contiguous
180+
// positions:
181+
// - register=1 -> (1, 0)
182+
// ...
183+
// - register=2**i -> (2**i, 0)
184+
// ...
185+
// - register=M -> (2**(N - M), 0)
186+
// ...
187+
// - register=2**k -> (0, 1)
188+
// ...
189+
// - register=2**N -> (0, 2**(M-1))
190+
// - lane=1 -> (0, 0)
191+
// ...
192+
// - lane=2**j -> (0, 0)
193+
// ...
194+
// lane=2**M -> (0, 0)
151195
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
152196
//
153197
// With N >= M.
154198
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
155199
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
156200
return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
157-
conversion->getBases().lookup(kRegister) ==
158-
buildSubGroupShuffleRegisterBases(registerInDimSize,
159-
laneOutDimSize);
201+
(conversion->getBases().lookup(kRegister) ==
202+
buildSubGroupShuffleRegisterBases(registerInDimSize,
203+
laneOutDimSize) ||
204+
conversion->getBases().lookup(kRegister) ==
205+
buildContiguousSubGroupShuffleRegisterBases(registerInDimSize,
206+
laneOutDimSize));
160207
}
161208

162209
bool isValidElementTypeForSubGroupTranspose(Type type) {

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
560560
return success();
561561
}
562562

563+
int getNumContiguousRowsForShuffle(const LinearLayout &srcLayout,
564+
const LinearLayout &dstLayout) const {
565+
MLIRContext *ctx = getContext();
566+
567+
StringAttr kRegister = str_attr("register");
568+
StringAttr kLane = str_attr("lane");
569+
StringAttr kWarp = str_attr("warp");
570+
StringAttr kBlock = str_attr("block");
571+
LinearLayout comp =
572+
*dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock});
573+
// Basic case: the number of contiguous rows is 1.
574+
if (comp.getBasis(kRegister, 0)[1] == 1)
575+
return 1;
576+
// In the other case, we only allow all elements handled by a single thread to
577+
// be contiguous, so we can simply:
578+
return comp.getOutDimSize(kRegister);
579+
}
580+
563581
void performSubGroupShuffle(ConvertLayoutOp op, const LinearLayout &srcLayout,
564582
const LinearLayout &dstLayout, OpAdaptor adaptor,
565583
ConversionPatternRewriter &rewriter) const {
@@ -605,8 +623,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
605623
});
606624
});
607625

608-
SmallVector<Value> outVals =
609-
performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
626+
SmallVector<Value> outVals = performSubGroupShuffle(
627+
loc, inVals, subGroupSize, rewriter,
628+
getNumContiguousRowsForShuffle(srcLayout, dstLayout));
610629

611630
// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
612631
// upstream level. We are not enabling support for all types here as that
@@ -636,19 +655,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
636655
rewriter.replaceOp(op, result);
637656
}
638657

639-
SmallVector<Value>
640-
performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
641-
int32_t subGroupSize,
642-
ConversionPatternRewriter &rewriter) const {
658+
SmallVector<Value> performSubGroupShuffle(Location loc,
659+
ArrayRef<Value> inVals,
660+
int32_t subGroupSize,
661+
ConversionPatternRewriter &rewriter,
662+
int numContiguousRows) const {
643663
SmallVector<Value> res;
644664
Value width = i32_val(subGroupSize);
645-
for (Value val : inVals) {
646-
for (int32_t i = 0; i < subGroupSize; ++i)
647-
res.push_back(
648-
rewriter
649-
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
650-
mlir::gpu::ShuffleMode::IDX)
651-
.getShuffleResult());
665+
// A work-item may handle more than one element. There are two cases we
666+
// support:
667+
if (numContiguousRows == 1) {
668+
// 1. Elements held by a work-item are strided rows in the abstract slice
669+
// matrix: Output element `i` will take the `i / 16`th value from the `i %
670+
// 16`th thread.
671+
for (Value val : inVals) {
672+
for (int32_t i = 0; i < subGroupSize; ++i) {
673+
res.push_back(
674+
rewriter
675+
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
676+
mlir::gpu::ShuffleMode::IDX)
677+
.getShuffleResult());
678+
}
679+
}
680+
} else {
681+
// 2. Elements held by a work-item are contiguous rows in the abstract
682+
// slice matrix: Output element `i` will take the `i % 16`th value from
683+
// the `i / 16`th thread.
684+
for (int32_t i = 0; i < subGroupSize; ++i) {
685+
for (Value val : inVals) {
686+
res.push_back(
687+
rewriter
688+
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
689+
mlir::gpu::ShuffleMode::IDX)
690+
.getShuffleResult());
691+
}
692+
}
652693
}
653694
return res;
654695
}

0 commit comments

Comments
 (0)