Skip to content

Commit e580308

Browse files
committed
[TritonIntelGPUToLLVM] Detect more sub-group shuffle convert_layout
Detect sub-group shuffle `convert_layout` cases of more than one element per thread. Signed-off-by: victor-eds <[email protected]>
1 parent e6df65e commit e580308

File tree

2 files changed

+183
-61
lines changed

2 files changed

+183
-61
lines changed

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
257257
tt.return %0 : tensor<32xf32, #sliced1>
258258
}
259259
}
260+
261+
// -----
262+
263+
// Case of more than one element per thread in the non-sliced dimension.
264+
265+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
266+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
267+
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
268+
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
269+
270+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
271+
// CHECK-LABEL: llvm.func spir_kernelcc @test(
272+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
273+
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
274+
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
275+
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
276+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]])
277+
// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32
278+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]])
279+
// CHECK: llvm.mlir.constant(true) : i1
280+
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32
281+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]])
282+
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32
283+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]])
284+
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32
285+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]])
286+
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32
287+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]])
288+
// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32
289+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]])
290+
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32
291+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]])
292+
// CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32
293+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]])
294+
// CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32
295+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]])
296+
// CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32
297+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]])
298+
// CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32
299+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]])
300+
// CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32
301+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]])
302+
// CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32
303+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]])
304+
// CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32
305+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]])
306+
// CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
307+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]])
308+
// CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32
309+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]])
310+
// CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32
311+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]])
312+
// CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32
313+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]])
314+
// CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32
315+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]])
316+
// CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32
317+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]])
318+
// CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32
319+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]])
320+
// CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32
321+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]])
322+
// CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32
323+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]])
324+
// CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32
325+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]])
326+
// CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32
327+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]])
328+
// CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32
329+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]])
330+
// CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32
331+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]])
332+
// CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32
333+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]])
334+
// CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32
335+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]])
336+
// CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32
337+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]])
338+
// CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32
339+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]])
340+
tt.func @test(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> {
341+
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1>
342+
tt.return %0 : tensor<32xf64, #sliced1>
343+
}
344+
}
345+
346+
// -----
347+
348+
// Case of more than one element per thread and 2 warps in the non-sliced dimension.
349+
350+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
351+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
352+
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
353+
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
354+
355+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
356+
// CHECK-LABEL: llvm.func spir_kernelcc @test
357+
// CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij
358+
tt.func @test(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> {
359+
%0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1>
360+
tt.return %0 : tensor<128xi32, #sliced1>
361+
}
362+
}

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 80 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,57 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
446446
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
447447
}
448448

449+
// Return a vector such as:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
// [registerSize / 2, 0]]
//
// The first group walks powers of two in the lane (second) coordinate up to
// laneSize; the second group walks powers of two in the register (first)
// coordinate from laneSize up to registerSize.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Bases [0, 2**k] for every 2**k < laneSize.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({0, stride});
  // Bases [2**k, 0] for every laneSize <= 2**k < registerSize.
  for (int32_t stride = laneSize; stride < registerSize; stride *= 2)
    result.push_back({stride, 0});
  return result;
}
467+
468+
// Return a vector such as:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
// [registerSize / (2 * laneSize), 0]]
//
// Like the transpose register bases, but the trailing group restarts the
// register coordinate at 1 instead of continuing from laneSize.
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Bases [0, 2**k] for every 2**k < laneSize.
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({0, stride});
  // Bases [2**k, 0] for 2**k < registerSize / laneSize.
  int32_t reg = 1;
  for (int32_t stride = laneSize; stride < registerSize; stride *= 2) {
    result.push_back({reg, 0});
    reg *= 2;
  }
  return result;
}
486+
487+
// Return a vector such as:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]]
//
// Powers of two in the register (first) coordinate, one basis per bit of the
// lane index.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  for (int32_t stride = 1; stride < laneSize; stride *= 2)
    result.push_back({stride, 0});
  return result;
}
499+
449500
bool isSubGroupTranspose(const LinearLayout &srcLayout,
450501
const LinearLayout &dstLayout) const {
451502
MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
@@ -476,35 +527,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
476527
// where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
477528
//
478529
// With N >= M.
479-
const auto buildBasis = [&](int32_t size, std::size_t index) {
480-
std::vector<std::vector<int32_t>> basis;
481-
std::vector<int32_t> curr(2);
482-
for (int32_t i = 1; i < size; i *= 2) {
483-
curr[index] = i;
484-
basis.push_back(curr);
485-
}
486-
return basis;
487-
};
488-
constexpr std::size_t laneIndex = 0;
489-
constexpr std::size_t registerIndex = 1;
490-
int32_t laneSize = conversion->getInDimSize(kLane);
491-
std::vector<std::vector<int32_t>> registerBases =
492-
buildBasis(laneSize, registerIndex);
493-
{
494-
// Populate register bases for N > M.
495-
std::vector<int32_t> base(2);
496-
for (int32_t i = laneSize,
497-
registerSize = conversion->getInDimSize(kRegister);
498-
i < registerSize; i *= 2) {
499-
base[laneIndex] = i;
500-
registerBases.push_back(base);
501-
}
502-
}
503-
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
504-
bases{{{kRegister, std::move(registerBases)},
505-
{kLane, buildBasis(laneSize, laneIndex)}}};
506-
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
507-
return conversion == LinearLayout(bases, outDimNames);
530+
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
531+
int32_t laneInDimSize = conversion->getInDimSize(kLane);
532+
return conversion->getBases().lookup(kRegister) ==
533+
buildSubGroupTransposeRegisterBases(registerInDimSize,
534+
laneInDimSize) &&
535+
conversion->getBases().lookup(kLane) ==
536+
buildSubGroupTransposeLaneBases(laneInDimSize);
508537
}
509538

510539
LogicalResult
@@ -619,32 +648,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
619648
// Expected conversion is:
620649
// - register=1 -> (0, 1)
621650
// ...
622-
// register=i -> (0, i)
651+
// - register=2**i -> (0, 2**i)
623652
// ...
624-
// register=N -> (0, N)
625-
// - lane=1 -> (0, 0)
653+
// - register=M -> (0, 2**M)
626654
// ...
627-
// lane=i -> (0, 0)
655+
// - register=2**k -> (2**(k-M), 0)
628656
// ...
629-
// lane=N -> (0, 0)
630-
// where out dims are: [register (size 1), lane (size N)]
631-
std::vector<std::vector<int32_t>> registerBases;
632-
{
633-
constexpr std::size_t registerIndex = 1;
634-
std::vector<int32_t> base(2);
635-
for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
636-
base[registerIndex] = i;
637-
registerBases.push_back(base);
638-
}
639-
}
640-
641-
std::vector<std::vector<int32_t>> laneBases(
642-
conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
643-
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
644-
bases{{{kRegister, std::move(registerBases)},
645-
{kLane, std::move(laneBases)}}};
646-
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
647-
return conversion == LinearLayout(bases, outDimNames);
657+
// - register=2**N -> (2**(N-M), 0)
658+
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
659+
//
660+
// With N >= M.
661+
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
662+
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
663+
return conversion->getBases().lookup(kRegister) ==
664+
buildSubGroupShuffleRegisterBases(registerInDimSize, laneOutDimSize);
648665
}
649666

650667
bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
@@ -674,7 +691,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
674691

675692
SmallVector<Value> inVals =
676693
unpackLLElements(loc, adaptor.getSrc(), rewriter);
677-
assert(inVals.size() == 1 && "Expecting single element");
678694

679695
// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
680696
// upstream level. We are not enabling support for all types here as that
@@ -703,7 +719,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
703719
});
704720

705721
SmallVector<Value> outVals =
706-
performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
722+
performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
707723

708724
// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
709725
// upstream level. We are not enabling support for all types here as that
@@ -734,16 +750,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
734750
}
735751

736752
SmallVector<Value>
737-
performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
753+
performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
754+
int32_t subGroupSize,
738755
ConversionPatternRewriter &rewriter) const {
739756
SmallVector<Value> res;
740757
Value width = i32_val(subGroupSize);
741-
for (int32_t i = 0; i < subGroupSize; ++i)
742-
res.push_back(
743-
rewriter
744-
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
745-
mlir::gpu::ShuffleMode::IDX)
746-
.getShuffleResult());
758+
for (Value val : inVals) {
759+
for (int32_t i = 0; i < subGroupSize; ++i)
760+
res.push_back(
761+
rewriter
762+
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
763+
mlir::gpu::ShuffleMode::IDX)
764+
.getShuffleResult());
765+
}
747766
return res;
748767
}
749768

0 commit comments

Comments
 (0)