103 changes: 103 additions & 0 deletions test/Conversion/intel/sub-group-shuffle.mlir
@@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return %0 : tensor<32xf32, #sliced1>
}
}

// -----

// Case of more than one element per thread in the non-sliced dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
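// With 16 lanes and 1 warp, each work-item holds 32 / 16 = 2 elements of the
// tensor, which is why the lowered kernel takes an !llvm.struct<(f64, f64)>
// argument and emits 2 * 16 = 32 sub-group shuffle calls below.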

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]])
// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]])
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]])
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]])
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]])
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]])
// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]])
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]])
// CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]])
// CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]])
// CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]])
// CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]])
// CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]])
// CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]])
// CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]])
// CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]])
// CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]])
// CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]])
// CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]])
// CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]])
// CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]])
// CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]])
// CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]])
// CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]])
// CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]])
// CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]])
// CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]])
// CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]])
// CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]])
// CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]])
// CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]])
// CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]])
tt.func @test(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1>
tt.return %0 : tensor<32xf64, #sliced1>
}
}

// -----

// Case of more than one element per thread and 2 warps in the non-sliced dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
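// With 16 lanes and 2 warps, each work-item holds 128 / 32 = 4 elements, so
// the conversion emits 4 * 16 = 64 shuffle calls, matching CHECK-COUNT-64.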

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test
// CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij
tt.func @test(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1>
tt.return %0 : tensor<128xi32, #sliced1>
}
}
151 changes: 91 additions & 60 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Expand Up @@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

// Return a vector of basis vectors such as:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
// [registerSize / 2, 0]],
// i.e., mapping registers to lanes up to laneSize and performing an
// identity mapping afterwards.
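// For illustration, registerSize = 32 and laneSize = 16 yield:
// [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0]].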
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
std::vector<std::vector<int32_t>> bases;
std::vector<int32_t> curr(2);
for (int32_t i = 1; i < laneSize; i *= 2) {
curr[1] = i;
bases.push_back(curr);
}
curr[1] = 0;
for (int32_t i = laneSize; i < registerSize; i *= 2) {
curr[0] = i;
bases.push_back(curr);
}
return bases;
}

// Return a vector such as:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
// [registerSize / (2 * laneSize), 0]],
// i.e., mapping registers to lanes up to laneSize and repeating the pattern
// afterwards.
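// For illustration, registerSize = 32 and laneSize = 16 yield:
// [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0]].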
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
std::vector<std::vector<int32_t>> bases;
std::vector<int32_t> curr(2);
for (int32_t i = 1; i < laneSize; i *= 2) {
curr[1] = i;
bases.push_back(curr);
}
curr[1] = 0;
for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
curr[0] = val;
bases.push_back(curr);
}
return bases;
}

// Return a vector such as:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
// i.e., mapping lanes to registers.
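// For illustration, laneSize = 16 yields: [[1, 0], [2, 0], [4, 0], [8, 0]].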
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
std::vector<std::vector<int32_t>> bases;
std::vector<int32_t> curr(2);
for (int32_t i = 1; i < laneSize; i *= 2) {
curr[0] = i;
bases.push_back(curr);
}
return bases;
}

bool isSubGroupTranspose(const LinearLayout &srcLayout,
const LinearLayout &dstLayout) const {
MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
@@ -476,35 +532,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
//
// With N >= M.
const auto buildBasis = [&](int32_t size, std::size_t index) {
std::vector<std::vector<int32_t>> basis;
std::vector<int32_t> curr(2);
for (int32_t i = 1; i < size; i *= 2) {
curr[index] = i;
basis.push_back(curr);
}
return basis;
};
constexpr std::size_t laneIndex = 0;
constexpr std::size_t registerIndex = 1;
int32_t laneSize = conversion->getInDimSize(kLane);
std::vector<std::vector<int32_t>> registerBases =
buildBasis(laneSize, registerIndex);
{
// Populate register bases for N > M.
std::vector<int32_t> base(2);
for (int32_t i = laneSize,
registerSize = conversion->getInDimSize(kRegister);
i < registerSize; i *= 2) {
base[laneIndex] = i;
registerBases.push_back(base);
}
}
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
bases{{{kRegister, std::move(registerBases)},
{kLane, buildBasis(laneSize, laneIndex)}}};
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
return conversion == LinearLayout(bases, outDimNames);
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneInDimSize = conversion->getInDimSize(kLane);
return conversion->getBases().lookup(kRegister) ==
buildSubGroupTransposeRegisterBases(registerInDimSize,
laneInDimSize) &&
conversion->getBases().lookup(kLane) ==
buildSubGroupTransposeLaneBases(laneInDimSize);
Comment on lines +535 to +541 (Contributor Author):
Refactored and protected against some illegal layout creation.

}

LogicalResult
@@ -619,32 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// Expected conversion is:
// - register=1 -> (0, 1)
// ...
// register=i -> (0, i)
// - register=2**i -> (0, 2**i)
// ...
// - register=2**M -> (0, 2**M)
// ...
// - register=2**k -> (2**(k-M), 0)
// ...
// register=N -> (0, N)
// - register=2**N -> (2**(N-M), 0)
// - lane=1 -> (0, 0)
// ...
// lane=i -> (0, 0)
// - lane=2**j -> (0, 0)
// ...
// lane=N -> (0, 0)
// where out dims are: [register (size 1), lane (size N)]
std::vector<std::vector<int32_t>> registerBases;
{
constexpr std::size_t registerIndex = 1;
std::vector<int32_t> base(2);
for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
base[registerIndex] = i;
registerBases.push_back(base);
}
}

std::vector<std::vector<int32_t>> laneBases(
conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
bases{{{kRegister, std::move(registerBases)},
{kLane, std::move(laneBases)}}};
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
return conversion == LinearLayout(bases, outDimNames);
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
//
// With N >= M.
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
conversion->getBases().lookup(kRegister) ==
buildSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize);
Comment on lines +674 to +676 (Contributor Author):
I will investigate whether there is a better way to express this in the future when generalizing. Same for the transpose case.

}

bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
@@ -674,7 +703,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

SmallVector<Value> inVals =
unpackLLElements(loc, adaptor.getSrc(), rewriter);
assert(inVals.size() == 1 && "Expecting single element");

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
@@ -703,7 +731,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
});

SmallVector<Value> outVals =
performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
@@ -734,16 +762,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

SmallVector<Value>
performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
int32_t subGroupSize,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> res;
Value width = i32_val(subGroupSize);
for (int32_t i = 0; i < subGroupSize; ++i)
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
for (Value val : inVals) {
for (int32_t i = 0; i < subGroupSize; ++i)
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
}
return res;
}
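// Note: this emits inVals.size() * subGroupSize shuffle ops in total; e.g.,
// 4 values per work-item with a sub-group size of 16 gives the 64 calls
// checked by CHECK-COUNT-64 in the lit test above.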
