
Commit 5d9774c

[TritonIntelGPUToLLVM] Detect more sub-group shuffle convert_layout (#2573)

Detect sub-group shuffle `convert_layout` cases of more than one element per thread.

Signed-off-by: victor-eds <[email protected]>

1 parent e438919 commit 5d9774c

2 files changed: +194 −60 lines

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 103 additions & 0 deletions
@@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return %0 : tensor<32xf32, #sliced1>
   }
 }
+
+// -----
+
+// Case of more than one element per thread in the non-sliced dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
+  // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
+  // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
+  // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]])
+  // CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]])
+  // CHECK: llvm.mlir.constant(true) : i1
+  // CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]])
+  // CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]])
+  // CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]])
+  // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]])
+  // CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]])
+  // CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]])
+  // CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]])
+  // CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]])
+  // CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]])
+  // CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]])
+  // CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]])
+  // CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]])
+  // CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]])
+  // CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]])
+  // CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]])
+  // CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]])
+  // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]])
+  // CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]])
+  // CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]])
+  // CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]])
+  // CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]])
+  // CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]])
+  // CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]])
+  // CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]])
+  // CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]])
+  // CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]])
+  // CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]])
+  // CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]])
+  // CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]])
+  // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]])
+  tt.func @test_non_sliced_multi_register(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1>
+    tt.return %0 : tensor<32xf64, #sliced1>
+  }
+}
+
+// -----
+
+// Case of more than one element per thread and 2 warps in the non-sliced dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register_multi_warp
+  // CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij
+  tt.func @test_non_sliced_multi_register_multi_warp(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1>
+    tt.return %0 : tensor<128xi32, #sliced1>
+  }
+}
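
Read as a whole, the CHECK sequence in the first test encodes a simple enumeration: for each of the two f64 registers a thread holds, the lowering emits one index shuffle per source lane, 2 × 16 = 32 shuffles in total. Below is a minimal C++ sketch of that enumeration; it is illustrative only (the container layout and names are invented here, not taken from the commit):

#include <array>
#include <vector>

// data[r][l] holds register r of lane l. An index shuffle
// sub_group_shuffle(v, i) hands every lane the value lane i holds, so the
// gathered sequence is identical on all lanes of the sub-group.
std::vector<double> gatherViaIndexShuffles(
    const std::vector<std::array<double, 16>> &data) {
  std::vector<double> out;
  for (const std::array<double, 16> &reg : data) // outer: registers
    for (int src = 0; src < 16; ++src)           // inner: source lanes
      out.push_back(reg[src]);                   // one shuffle per (reg, lane)
  return out;
}

int main() {
  std::vector<std::array<double, 16>> regs(2); // two f64 registers per thread
  return gatherViaIndexShuffles(regs).size() == 32 ? 0 : 1;
}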

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 91 additions & 60 deletions
@@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
+  // Return a vector such as:
+  // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
+  // [registerSize / 2, 0]],
+  // i.e., mapping registers to lanes till laneSize and performing an ID
+  // conversion afterwards.
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[1] = i;
+      bases.push_back(curr);
+    }
+    curr[1] = 0;
+    for (int32_t i = laneSize; i < registerSize; i *= 2) {
+      curr[0] = i;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
+  // Return a vector such as:
+  // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
+  // [registerSize / (2 * laneSize), 0]],
+  // i.e., mapping registers to lanes till laneSize and repeating the pattern
+  // afterwards.
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[1] = i;
+      bases.push_back(curr);
+    }
+    curr[1] = 0;
+    for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
+      curr[0] = val;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
+  // Return a vector such as:
+  // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
+  // i.e., mapping lanes to registers.
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupTransposeLaneBases(int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[0] = i;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
   bool isSubGroupTranspose(const LinearLayout &srcLayout,
                            const LinearLayout &dstLayout) const {
     MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
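
To see concretely what these helpers build, here is a small standalone driver. The two function bodies are copied verbatim from the hunk above; the driver itself and the sizes (64 registers, 16 lanes) are illustrative assumptions:

#include <cstdint>
#include <iostream>
#include <vector>

using Bases = std::vector<std::vector<int32_t>>;

// Copied from the diff above so this sketch compiles on its own.
static Bases buildSubGroupTransposeRegisterBases(int32_t registerSize,
                                                 int32_t laneSize) {
  Bases bases;
  std::vector<int32_t> curr(2);
  for (int32_t i = 1; i < laneSize; i *= 2) {
    curr[1] = i;
    bases.push_back(curr);
  }
  curr[1] = 0;
  for (int32_t i = laneSize; i < registerSize; i *= 2) {
    curr[0] = i;
    bases.push_back(curr);
  }
  return bases;
}

static Bases buildSubGroupShuffleRegisterBases(int32_t registerSize,
                                               int32_t laneSize) {
  Bases bases;
  std::vector<int32_t> curr(2);
  for (int32_t i = 1; i < laneSize; i *= 2) {
    curr[1] = i;
    bases.push_back(curr);
  }
  curr[1] = 0;
  for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
    curr[0] = val;
    bases.push_back(curr);
  }
  return bases;
}

static void dump(const char *name, const Bases &bases) {
  std::cout << name << ":";
  for (const std::vector<int32_t> &b : bases)
    std::cout << " [" << b[0] << ", " << b[1] << "]";
  std::cout << "\n";
}

int main() {
  // Transpose: registers map to lanes, then to higher register IDs:
  //   transpose: [0, 1] [0, 2] [0, 4] [0, 8] [16, 0] [32, 0]
  dump("transpose", buildSubGroupTransposeRegisterBases(64, 16));
  // Shuffle: registers map to lanes, then the pattern restarts from 1:
  //   shuffle: [0, 1] [0, 2] [0, 4] [0, 8] [1, 0] [2, 0]
  dump("shuffle", buildSubGroupShuffleRegisterBases(64, 16));
}

The two builders differ only past the sub-group size: the transpose variant keeps counting in units of laneSize, while the shuffle variant wraps around to 1.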
@@ -476,35 +532,13 @@
     // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
     //
     // With N >= M.
-    const auto buildBasis = [&](int32_t size, std::size_t index) {
-      std::vector<std::vector<int32_t>> basis;
-      std::vector<int32_t> curr(2);
-      for (int32_t i = 1; i < size; i *= 2) {
-        curr[index] = i;
-        basis.push_back(curr);
-      }
-      return basis;
-    };
-    constexpr std::size_t laneIndex = 0;
-    constexpr std::size_t registerIndex = 1;
-    int32_t laneSize = conversion->getInDimSize(kLane);
-    std::vector<std::vector<int32_t>> registerBases =
-        buildBasis(laneSize, registerIndex);
-    {
-      // Populate register bases for N > M.
-      std::vector<int32_t> base(2);
-      for (int32_t i = laneSize,
-                   registerSize = conversion->getInDimSize(kRegister);
-           i < registerSize; i *= 2) {
-        base[laneIndex] = i;
-        registerBases.push_back(base);
-      }
-    }
-    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, std::move(registerBases)},
-               {kLane, buildBasis(laneSize, laneIndex)}}};
-    std::array<StringAttr, 2> outDimNames{kRegister, kLane};
-    return conversion == LinearLayout(bases, outDimNames);
+    int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+    int32_t laneInDimSize = conversion->getInDimSize(kLane);
+    return conversion->getBases().lookup(kRegister) ==
+               buildSubGroupTransposeRegisterBases(registerInDimSize,
+                                                   laneInDimSize) &&
+           conversion->getBases().lookup(kLane) ==
+               buildSubGroupTransposeLaneBases(laneInDimSize);
   }
 
   LogicalResult
@@ -619,32 +653,27 @@
     // Expected conversion is:
     // - register=1 -> (0, 1)
     // ...
-    // register=i -> (0, i)
+    // - register=2**i -> (0, 2**i)
+    // ...
+    // - register=2**M -> (0, 2**M)
+    // ...
+    // - register=2**k -> (2**(k-M), 0)
     // ...
-    // register=N -> (0, N)
+    // - register=2**N -> (2**(N-M), 0)
     // - lane=1 -> (0, 0)
     // ...
-    // lane=i -> (0, 0)
+    // - lane=2**j -> (0, 0)
     // ...
-    // lane=N -> (0, 0)
-    // where out dims are: [register (size 1), lane (size N)]
-    std::vector<std::vector<int32_t>> registerBases;
-    {
-      constexpr std::size_t registerIndex = 1;
-      std::vector<int32_t> base(2);
-      for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
-        base[registerIndex] = i;
-        registerBases.push_back(base);
-      }
-    }
-
-    std::vector<std::vector<int32_t>> laneBases(
-        conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
-    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, std::move(registerBases)},
-               {kLane, std::move(laneBases)}}};
-    std::array<StringAttr, 2> outDimNames{kRegister, kLane};
-    return conversion == LinearLayout(bases, outDimNames);
+    // - lane=2**M -> (0, 0)
+    // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
+    //
+    // With N >= M.
+    int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+    int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
+    return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
+           conversion->getBases().lookup(kRegister) ==
+               buildSubGroupShuffleRegisterBases(registerInDimSize,
+                                                 laneOutDimSize);
   }
 
   bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
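
The rewritten predicate splits the old whole-layout comparison into two cheaper checks: the register bases must match buildSubGroupShuffleRegisterBases, and the lane sublayout must be all zeros, i.e. the conversion is lane-invariant, which is what makes a broadcast shuffle a valid lowering. Below is a toy stand-in for the second check; this is not Triton's LinearLayout API, only an illustration of the property that sublayoutIsZero verifies:

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for conversion->sublayoutIsZero({kLane}, {kRegister, kLane}):
// every basis of the lane input dimension must be the zero vector.
static bool laneSublayoutIsZero(
    const std::vector<std::vector<int32_t>> &laneBases) {
  for (const std::vector<int32_t> &basis : laneBases)
    for (int32_t component : basis)
      if (component != 0)
        return false;
  return true;
}

int main() {
  // log2(16) = 4 lane bases for a 16-lane sub-group, all zero: the output
  // does not depend on which lane computes it, so broadcasting index
  // shuffles can realize the conversion.
  std::vector<std::vector<int32_t>> laneBases(4, {0, 0});
  assert(laneSublayoutIsZero(laneBases));
}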
@@ -674,7 +703,6 @@
 
     SmallVector<Value> inVals =
         unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(inVals.size() == 1 && "Expecting single element");
 
     // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
     // upstream level. We are not enabling support for all types here as that
@@ -703,7 +731,7 @@
     });
 
     SmallVector<Value> outVals =
-        performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
+        performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
 
     // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
     // upstream level. We are not enabling support for all types here as that
@@ -734,16 +762,19 @@
   }
 
   SmallVector<Value>
-  performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
+  performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
+                         int32_t subGroupSize,
                          ConversionPatternRewriter &rewriter) const {
     SmallVector<Value> res;
     Value width = i32_val(subGroupSize);
-    for (int32_t i = 0; i < subGroupSize; ++i)
-      res.push_back(
-          rewriter
-              .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
-                                            mlir::gpu::ShuffleMode::IDX)
-              .getShuffleResult());
+    for (Value val : inVals) {
+      for (int32_t i = 0; i < subGroupSize; ++i)
+        res.push_back(
+            rewriter
+                .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
+                                              mlir::gpu::ShuffleMode::IDX)
+                .getShuffleResult());
+    }
     return res;
   }
 
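
With the added outer loop, the number of emitted shuffles is inVals.size() × subGroupSize. A quick sanity check against the CHECK-COUNT-64 test above, with the parameters read off that test (128 elements, 2 warps, 16 lanes; the arithmetic itself is the point):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t tensorSize = 128;  // tensor<128xi32, #sliced>
  const int32_t numWarps = 2;      // "triton_gpu.num-warps" = 2
  const int32_t subGroupSize = 16; // "triton_gpu.threads-per-warp" = 16
  // Each thread holds 128 / (2 * 16) = 4 registers of the source tensor.
  const int32_t registersPerThread = tensorSize / (numWarps * subGroupSize);
  // One shuffle per (register, source lane) pair:
  assert(registersPerThread * subGroupSize == 64); // matches CHECK-COUNT-64
}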
