Commit 6e24d02

[TritonIntelGPUToLLVM] Detect more sub-group shuffle convert_layout

Detect sub-group shuffle `convert_layout` cases with more than one element per thread.

Signed-off-by: victor-eds <[email protected]>

1 parent 529ca78 commit 6e24d02

File tree

2 files changed: +186 -62 lines changed

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 103 additions & 0 deletions
@@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return %0 : tensor<32xf32, #sliced1>
   }
 }
+
+// -----
+
+// Case of more than one element per thread in the non-sliced dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME:  %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
+  // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
+  // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
+  // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]])
+  // CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]])
+  // CHECK: llvm.mlir.constant(true) : i1
+  // CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]])
+  // CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]])
+  // CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]])
+  // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]])
+  // CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]])
+  // CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]])
+  // CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]])
+  // CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]])
+  // CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]])
+  // CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]])
+  // CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]])
+  // CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]])
+  // CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]])
+  // CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]])
+  // CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]])
+  // CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]])
+  // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]])
+  // CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]])
+  // CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]])
+  // CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]])
+  // CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]])
+  // CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]])
+  // CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]])
+  // CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]])
+  // CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]])
+  // CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]])
+  // CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]])
+  // CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]])
+  // CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]])
+  // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]])
+  tt.func @test(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1>
+    tt.return %0 : tensor<32xf64, #sliced1>
+  }
+}
+
+// -----
+
+// Case of more than one element per thread and 2 warps in the non-sliced dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test
+  // CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij
+  tt.func @test(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1>
+    tt.return %0 : tensor<128xi32, #sliced1>
+  }
+}
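The number of CHECK'd shuffle calls above follows directly from the emission loop this commit adds to performSubGroupShuffle (see the C++ diff below): one shuffle per unpacked source element per lane index. A minimal standalone sketch, not part of the patch (the element names are illustrative stand-ins for the unpacked struct values %[[VAL_2]] and %[[VAL_3]]):

// Mirrors the double loop in performSubGroupShuffle: with two f64 elements
// per work-item and a sub-group size of 16, 2 * 16 = 32 shuffles are emitted,
// matching the 32 @_Z17sub_group_shuffledj calls CHECK'd above. The two-warp
// i32 test emits twice that, hence CHECK-COUNT-64.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Stand-ins for the unpacked LLVM struct elements.
  std::vector<const char *> inVals = {"%val_2", "%val_3"};
  constexpr int32_t subGroupSize = 16;
  int32_t emitted = 0;
  for (const char *val : inVals)
    for (int32_t i = 0; i < subGroupSize; ++i, ++emitted)
      std::printf("sub_group_shuffle(%s, %d)\n", val, i);
  std::printf("total shuffles: %d\n", emitted); // prints 32
  return 0;
}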

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 83 additions & 62 deletions
@@ -446,6 +446,57 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
+  // Return a vector such as:
+  // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
+  // [registerSize / 2, 0]]
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[1] = i;
+      bases.push_back(curr);
+    }
+    curr[1] = 0;
+    for (int32_t i = laneSize; i < registerSize; i *= 2) {
+      curr[0] = i;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
+  // Return a vector such as:
+  // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
+  // [registerSize / (2 * laneSize), 0]]
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[1] = i;
+      bases.push_back(curr);
+    }
+    curr[1] = 0;
+    for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
+      curr[0] = val;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
+  // Return a vector such as:
+  // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]]
+  static std::vector<std::vector<int32_t>>
+  buildSubGroupTransposeLaneBases(int32_t laneSize) {
+    std::vector<std::vector<int32_t>> bases;
+    std::vector<int32_t> curr(2);
+    for (int32_t i = 1; i < laneSize; i *= 2) {
+      curr[0] = i;
+      bases.push_back(curr);
+    }
+    return bases;
+  }
+
   bool isSubGroupTranspose(const LinearLayout &srcLayout,
                            const LinearLayout &dstLayout) const {
     MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
@@ -477,35 +528,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
     //
     // With N >= M.
-    const auto buildBasis = [&](int32_t size, std::size_t index) {
-      std::vector<std::vector<int32_t>> basis;
-      std::vector<int32_t> curr(2);
-      for (int32_t i = 1; i < size; i *= 2) {
-        curr[index] = i;
-        basis.push_back(curr);
-      }
-      return basis;
-    };
-    constexpr std::size_t laneIndex = 0;
-    constexpr std::size_t registerIndex = 1;
-    int32_t laneSize = conversion->getInDimSize(kLane);
-    std::vector<std::vector<int32_t>> registerBases =
-        buildBasis(laneSize, registerIndex);
-    {
-      // Populate register bases for N > M.
-      std::vector<int32_t> base(2);
-      for (int32_t i = laneSize,
-                   registerSize = conversion->getInDimSize(kRegister);
-           i < registerSize; i *= 2) {
-        base[laneIndex] = i;
-        registerBases.push_back(base);
-      }
-    }
-    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, std::move(registerBases)},
-               {kLane, buildBasis(laneSize, laneIndex)}}};
-    std::array<StringAttr, 2> outDimNames{kRegister, kLane};
-    return conversion == LinearLayout(bases, outDimNames);
+    int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+    int32_t laneInDimSize = conversion->getInDimSize(kLane);
+    return conversion->getBases().lookup(kRegister) ==
+               buildSubGroupTransposeRegisterBases(registerInDimSize,
+                                                   laneInDimSize) &&
+           conversion->getBases().lookup(kLane) ==
+               buildSubGroupTransposeLaneBases(laneInDimSize);
   }
 
   LogicalResult
@@ -612,39 +641,29 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
     std::optional<LinearLayout> conversion = comp.divideRight(
+        LinearLayout::zeros1D(comp.getInDimSize(kLane), kLane, kLane) *
         LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
        LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
-    assert(conversion && "Expecting valid conversion");
+    if (!conversion)
+      return false;
     // TODO: Support more kind of shuffles.
     // Expected conversion is:
     // - register=1 -> (0, 1)
     // ...
-    // register=i -> (0, i)
+    // - register=2**i -> (0, 2**i)
     // ...
-    // register=N -> (0, N)
-    // - lane=1 -> (0, 0)
+    // - register=M -> (0, 2**M)
     // ...
-    // lane=i -> (0, 0)
+    // - register=2**k -> (2**(k-M), 0)
     // ...
-    // lane=N -> (0, 0)
-    // where out dims are: [register (size 1), lane (size N)]
-    std::vector<std::vector<int32_t>> registerBases;
-    {
-      constexpr std::size_t registerIndex = 1;
-      std::vector<int32_t> base(2);
-      for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
-        base[registerIndex] = i;
-        registerBases.push_back(base);
-      }
-    }
-
-    std::vector<std::vector<int32_t>> laneBases(
-        conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
-    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, std::move(registerBases)},
-               {kLane, std::move(laneBases)}}};
-    std::array<StringAttr, 2> outDimNames{kRegister, kLane};
-    return conversion == LinearLayout(bases, outDimNames);
+    // - register=2**N -> (2**(N-M), 0)
+    // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
+    //
+    // With N >= M.
+    int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+    int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
+    return conversion->getBases().lookup(kRegister) ==
+           buildSubGroupShuffleRegisterBases(registerInDimSize, laneOutDimSize);
   }
 
   bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
@@ -677,7 +696,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     SmallVector<Value> inVals =
         unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(inVals.size() == 1 && "Expecting single element");
 
     // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
     // upstream level. We are not enabling support for all types here as that
@@ -706,7 +724,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     });
 
     SmallVector<Value> outVals =
-        performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
+        performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);
 
     // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
     // upstream level. We are not enabling support for all types here as that
@@ -737,16 +755,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }
 
   SmallVector<Value>
-  performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
+  performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
+                         int32_t subGroupSize,
                          ConversionPatternRewriter &rewriter) const {
     SmallVector<Value> res;
     Value width = i32_val(subGroupSize);
-    for (int32_t i = 0; i < subGroupSize; ++i)
-      res.push_back(
-          rewriter
-              .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
-                                            mlir::gpu::ShuffleMode::IDX)
-              .getShuffleResult());
+    for (Value val : inVals) {
+      for (int32_t i = 0; i < subGroupSize; ++i)
+        res.push_back(
+            rewriter
+                .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
+                                              mlir::gpu::ShuffleMode::IDX)
+                .getShuffleResult());
+    }
     return res;
   }
 
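For a concrete feel of the basis patterns the new detection code compares against, here is a standalone sketch, not part of the patch, that reproduces two of the helpers added above verbatim and prints their output for registerSize = 32 and laneSize = 16, the two-elements-per-thread, 16-lane configuration exercised by the f64 test:

// Prints the register bases used by the shuffle and transpose checks.
// Expected output:
//   shuffle:   [0,1] [0,2] [0,4] [0,8] [1,0]
//   transpose: [0,1] [0,2] [0,4] [0,8] [16,0]
// In the shuffle case, the low log2(laneSize) register bits map to lanes and
// the one remaining register bit stays a register bit, so each element held
// by a work-item is broadcast to the sub-group independently.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  for (int32_t i = 1; i < laneSize; i *= 2) {
    curr[1] = i;
    bases.push_back(curr);
  }
  curr[1] = 0;
  for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
    curr[0] = val;
    bases.push_back(curr);
  }
  return bases;
}

static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  for (int32_t i = 1; i < laneSize; i *= 2) {
    curr[1] = i;
    bases.push_back(curr);
  }
  curr[1] = 0;
  for (int32_t i = laneSize; i < registerSize; i *= 2) {
    curr[0] = i;
    bases.push_back(curr);
  }
  return bases;
}

static void dump(const char *name,
                 const std::vector<std::vector<int32_t>> &bases) {
  std::printf("%s:", name);
  for (const auto &b : bases)
    std::printf(" [%d,%d]", b[0], b[1]);
  std::printf("\n");
}

int main() {
  dump("shuffle", buildSubGroupShuffleRegisterBases(32, 16));
  dump("transpose", buildSubGroupTransposeRegisterBases(32, 16));
  return 0;
}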