18 changes: 18 additions & 0 deletions test/Conversion/intel/intel-allocate-shared-memory.mlir
@@ -18,6 +18,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 0 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK: tt.func @test_sub_group_shuffle
// CHECK-NOT: llvm.ptr<3>
tt.func @test_sub_group_shuffle(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

35 changes: 35 additions & 0 deletions test/Conversion/intel/sub-group-shuffle.mlir
@@ -360,3 +360,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
tt.return %0 : tensor<128xi32, #sliced1>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Case of more than one contiguous element per work-item.

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>)
tt.func @test_contiguous(%arg0: tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16, f16)>
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f16, f16)>
// COM: Check the shuffles are "coalesced"
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
Reviewer comment (Contributor):
Ah, it would be nice if CHECK-COUNT could work with more than one line (so we could check that a pattern involving more than one line repeats a specified number of times). AFAIK this is not possible, though.

// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_1]]
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]]
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %0 : tensor<32xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}
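Regarding the reviewer note above: if the repeated pattern were a single line, FileCheck's CHECK-COUNT-<n> directive could collapse the sixteen checks into one. A hedged sketch follows; it is not used in the test itself, since it would drop the check that the two extracted values alternate:

// CHECK-COUNT-16: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(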
57 changes: 52 additions & 5 deletions third_party/intel/lib/Analysis/Utility.cpp
@@ -71,6 +71,29 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
return bases;
}

// Return a vector such as:
// [[1, 0], [2, 0], [4, 0], ..., [registerSize / (2 * laneSize), 0], [0, 1],
// ..., [0, laneSize / 2]],
// i.e., mapping registers to registers up to registerSize / laneSize (all the
// contiguous registers) and then to lanes.
std::vector<std::vector<int32_t>>
buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
int32_t laneSize) {
std::vector<std::vector<int32_t>> bases;
std::vector<int32_t> curr(2);
int i = 1;
Reviewer comment (Contributor):
[NIT]: int -> int32_t (consistency with surrounding code)

for (; i < registerSize / laneSize; i *= 2) {
curr[0] = i;
bases.push_back(curr);
}
curr[0] = 0;
for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
curr[1] = val;
bases.push_back(curr);
}
return bases;
}
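As a quick sanity check for the shapes exercised by the new @test_contiguous test (assuming registerSize = 32 and laneSize = 16 for that conversion), here is a standalone sketch mirroring the routine above; the expected bases are an illustration only, not part of the patch:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of buildContiguousSubGroupShuffleRegisterBases, for illustration only.
static std::vector<std::vector<int32_t>> contiguousBases(int32_t registerSize,
                                                         int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  std::vector<int32_t> curr(2);
  int32_t i = 1;
  // Contiguous registers map to registers...
  for (; i < registerSize / laneSize; i *= 2) {
    curr[0] = i;
    bases.push_back(curr);
  }
  // ...and the remaining input bits map to lanes.
  curr[0] = 0;
  for (int32_t val = 1; i < registerSize; i *= 2, val *= 2) {
    curr[1] = val;
    bases.push_back(curr);
  }
  return bases;
}

int main() {
  // 32 registers spread over a 16-lane sub-group, two contiguous rows each.
  std::vector<std::vector<int32_t>> expected = {
      {1, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 8}};
  assert(contiguousBases(/*registerSize=*/32, /*laneSize=*/16) == expected);
  return 0;
}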

// Return a vector such as:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
// i.e., mapping lanes to registers.
@@ -138,25 +161,49 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
// ...
// - register=2**i -> (0, 2**i)
// ...
// - register=M -> (0, 2**(M-1))
// - register=M+1 -> (1, 0)
// ...
// - register=2**k -> (2**(k-M), 0)
// ...
// - register=2**N -> (2**(N-M), 0)
// - lane=1 -> (0, 0)
// ...
// - lane=2**j -> (0, 0)
// ...
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**N), lane (size 2**M)]
//
// With N >= M.
//
// Or, when the elements managed by a given work-item are in contiguous
// positions:
// - register=1 -> (1, 0)
// ...
// - register=2**i -> (2**i, 0)
// ...
// - register=M -> (2**(N - M), 0)
// ...
// - register=2**k -> (0, 1)
// ...
// - register=2**N -> (0, 2**(M-1))
// - lane=1 -> (0, 0)
// ...
// - lane=2**j -> (0, 0)
// ...
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
//
// With N >= M.
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
(conversion->getBases().lookup(kRegister) ==
buildSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize) ||
conversion->getBases().lookup(kRegister) ==
buildContiguousSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize));
}

bool isValidElementTypeForSubGroupTranspose(Type type) {
@@ -560,6 +560,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return success();
}

int getNumContiguousRowsForShuffle(const LinearLayout &srcLayout,
const LinearLayout &dstLayout) const {
MLIRContext *ctx = getContext();

StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");
LinearLayout comp =
*dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock});
// Basic case: the number of contiguous rows is 1.
if (comp.getBasis(kRegister, 0)[1] == 1)
return 1;
// Otherwise, we only support the case in which all the elements handled by a
// single work-item are contiguous, so the number of contiguous rows is
// simply:
return comp.getOutDimSize(kRegister);
}
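For intuition, a hypothetical standalone model of this decision (plain vectors instead of the LinearLayout API; the bases and sizes below are assumptions matching the shuffle patterns described in Utility.cpp, not values extracted from this patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Each basis is (register component, lane component) of the composed
// conversion's "register" input dimension.
static int numContiguousRows(
    const std::vector<std::vector<int32_t>> &registerBases,
    int32_t registerOutDimSize) {
  // First register basis lands on the lane dimension: rows are strided.
  if (registerBases.front()[1] == 1)
    return 1;
  // Otherwise all rows owned by one work-item are contiguous.
  return registerOutDimSize;
}

int main() {
  // Strided case (bases start with a lane step): one contiguous row.
  assert(numContiguousRows({{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}, 2) == 1);
  // Contiguous case (bases start with a register step): all rows contiguous.
  assert(numContiguousRows({{1, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 8}}, 2) == 2);
  return 0;
}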

void performSubGroupShuffle(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -605,8 +623,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
});
});

SmallVector<Value> outVals = performSubGroupShuffle(
loc, inVals, subGroupSize, rewriter,
getNumContiguousRowsForShuffle(srcLayout, dstLayout));

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
@@ -636,19 +655,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
rewriter.replaceOp(op, result);
}

SmallVector<Value> performSubGroupShuffle(Location loc,
ArrayRef<Value> inVals,
int32_t subGroupSize,
ConversionPatternRewriter &rewriter,
int numContiguousRows) const {
SmallVector<Value> res;
Value width = i32_val(subGroupSize);
// A work-item may handle more than one element. There are two cases we
// support:
if (numContiguousRows == 1) {
// 1. Elements held by a work-item are strided rows in the abstract slice
// matrix: Output element `i` will take the `i / 16`th value from the `i %
// 16`th thread.
for (Value val : inVals) {
for (int32_t i = 0; i < subGroupSize; ++i) {
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
}
}
} else {
// 2. Elements held by a work-item are contiguous rows in the abstract
// slice matrix: Output element `i` will take the `i % 16`th value from
// the `i / 16`th thread.
for (int32_t i = 0; i < subGroupSize; ++i) {
for (Value val : inVals) {
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
}
}
}
return res;
}
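A small standalone model of the two shuffle orderings above (plain strings instead of mlir::gpu::ShuffleOp; the 16-lane sub-group with two values per work-item matches the new test, but is otherwise an assumption):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Models only the order in which shuffles are emitted, not the lowering.
static std::vector<std::string> shuffleOrder(int32_t subGroupSize,
                                             int32_t numVals,
                                             bool contiguousRows) {
  std::vector<std::string> order;
  if (!contiguousRows) {
    // Case 1 (strided rows): values outer, lanes inner.
    for (int32_t v = 0; v < numVals; ++v)
      for (int32_t lane = 0; lane < subGroupSize; ++lane)
        order.push_back("shuffle(val" + std::to_string(v) + ", lane " +
                        std::to_string(lane) + ")");
  } else {
    // Case 2 (contiguous rows): lanes outer, values inner, so the shuffles of
    // one lane's values end up adjacent (the "coalesced" pattern the test
    // checks for).
    for (int32_t lane = 0; lane < subGroupSize; ++lane)
      for (int32_t v = 0; v < numVals; ++v)
        order.push_back("shuffle(val" + std::to_string(v) + ", lane " +
                        std::to_string(lane) + ")");
  }
  return order;
}

int main() {
  for (const std::string &s : shuffleOrder(16, 2, /*contiguousRows=*/true))
    std::cout << s << "\n"; // val0/val1 alternate, as in @test_contiguous.
  return 0;
}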