247 changes: 244 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -35,6 +35,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

namespace mlir {
namespace xegpu {
@@ -174,6 +175,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}

/// Given a sequential and distributed vector type, return the list of
/// dimensions that are distributed.
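/// For example, for a sequential type vector<8x32xf32> distributed as
/// vector<8x2xf32>, only dimension 1 is distributed.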
static SmallVector<int64_t> getDistributedDims(VectorType sequentialType,
VectorType distributedType) {
assert(sequentialType.getRank() == distributedType.getRank() &&
"sequential and distributed vector types must have the same rank");
SmallVector<int64_t> distributedDims;
for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
distributedDims.push_back(i);
}
}
return distributedDims;
}

/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
@@ -1469,6 +1485,227 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};

/// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
/// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
/// the advanced cases where the distributed dimension is partially extracted,
/// which are currently not supported by the generic vector distribution
/// patterns.
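/// Illustrative example (shapes are hypothetical; assumes subgroup size 16, a
/// [1, 16] lane layout and unit lane data):
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
///     ...
///     %e = vector.extract_strided_slice %src
///       {offsets = [0, 16], sizes = [8, 16], strides = [1, 1]}
///       : vector<8x32xf32> to vector<8x16xf32>
///     gpu.yield %e : vector<8x16xf32>
///   }
/// is rewritten so that the distributed source is yielded from the warp op and
/// the extraction is performed per lane with rescaled sizes and offsets:
///   %w = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x2xf32>) {
///     ...
///     gpu.yield %src : vector<8x32xf32>
///   }
///   %r = vector.extract_strided_slice %w
///     {offsets = [0, 1], sizes = [8, 1], strides = [1, 1]}
///     : vector<8x2xf32> to vector<8x1xf32>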
struct VectorExtractStridedSliceDistribution
: public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
if (!operand)
return failure();
auto extractOp =
cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
unsigned operandIdx = operand->getOperandNumber();
auto distributedType =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
// Find the distributed dimension. There should be exactly one.
auto extractResultType = cast<VectorType>(operand->get().getType());
auto distributedDims =
getDistributedDims(extractResultType, distributedType);
// Collect updated source type, sizes and offsets. They may be adjusted
// later if the data is distributed to lanes (as opposed to being owned by
// all lanes uniformly).
VectorType updatedSourceType = extractOp.getSourceVectorType();
SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
extractOp.getSizes(), [](Attribute attr) { return attr; });
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
extractOp.getOffsets(), [](Attribute attr) { return attr; });
// If the result is distributed, it must be distributed in exactly one
// dimension. In this case, we adjust the updated source type, sizes and
// offsets accordingly.
if (distributedDims.size() > 0) {
if (distributedDims.size() != 1)
return rewriter.notifyMatchFailure(
warpOp, "Source cannot be distributed in multiple dimensions.");
int64_t distributedDim = distributedDims[0];
int sourceDistrDimSize =
extractOp.getSourceVectorType().getShape()[distributedDim];
auto sourceLayout =
xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
return rewriter.notifyMatchFailure(
warpOp, "the source of extract_strided_slice op lacks distribution "
"layout");
auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
// Because only single dimension distribution is supported, lane layout
// size at the distributed dim must be the subgroup size.
int subgroupSize = sourceLaneLayout[distributedDim];
// Check if the source size in the distributed dimension is a multiple of
// subgroup size.
if (sourceDistrDimSize % subgroupSize != 0)
return rewriter.notifyMatchFailure(
warpOp,
"Source size along distributed dimension is not a multiple of "
"subgroup size.");
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
// We expect lane data to be all ones in this case.
if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(
warpOp, "Expecting unit lane data in source layout");
// The offset in the distributed dimension must be a multiple of the
// subgroup size.
int64_t distrDimOffset =
cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
if (distrDimOffset % subgroupSize != 0)
return rewriter.notifyMatchFailure(
warpOp, "Offset along distributed dimension "
"is not a multiple of subgroup size.");
updatedSourceType = getDistVecTypeBasedOnLaneLayout(
sourceLayout, extractOp.getSourceVectorType())
.value();
// Update the distributed sizes to match the distributed type.
updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
distributedType.getDimSize(distributedDim));
// Update the distributed offsets to match round robin distribution (i.e.
// each lane owns data at `subgroupSize` stride given unit lane data).
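// For example, with a subgroup size of 16, a sequential offset of 32 maps to
// a per-lane offset of 2.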
updatedOffsets[distributedDim] =
rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
}
// Do the distribution by yielding the source of the extract op from
// the warp op and creating a new extract op outside the warp op.
SmallVector<size_t> newRetIndices;
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value source = newWarpOp.getResult(newRetIndices[0]);
// Create a new extract op outside the warp op.
Value newExtractOp = vector::ExtractStridedSliceOp::create(
rewriter, extractOp.getLoc(), distributedType, source,
ArrayAttr::get(rewriter.getContext(), updatedOffsets),
ArrayAttr::get(rewriter.getContext(), updatedSizes),
extractOp.getStrides());
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
return success();
}
};

/// Distribute a `vector.insert_strided_slice` op feeding into the yield op of
/// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
/// the advanced cases where the distributed dimension is partially inserted,
/// which are currently not supported by the generic vector distribution
/// patterns.
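/// Illustrative example (shapes are hypothetical; assumes subgroup size 16, a
/// [1, 16] lane layout and unit lane data on both operands):
///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x2xf32>) {
///     ...
///     %i = vector.insert_strided_slice %val, %dst
///       {offsets = [0, 16], strides = [1, 1]}
///       : vector<8x16xf32> into vector<8x32xf32>
///     gpu.yield %i : vector<8x32xf32>
///   }
/// is rewritten so that the distributed source and dest are yielded from the
/// warp op and the insertion is performed per lane with a rescaled offset:
///   %w:2 = gpu.warp_execute_on_lane_0(%laneid)[16]
///       -> (vector<8x1xf32>, vector<8x2xf32>) {
///     ...
///     gpu.yield %val, %dst : vector<8x16xf32>, vector<8x32xf32>
///   }
///   %r = vector.insert_strided_slice %w#0, %w#1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<8x1xf32> into vector<8x2xf32>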
struct VectorInsertStridedSliceDistribution
: public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp =
operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Find the distributed dimension of the dest vector. There should be
// exactly one.
auto insertResultType = cast<VectorType>(operand->get().getType());
auto destDistributedDims =
getDistributedDims(insertResultType, distributedType);
// Collect updated offsets, source type and dest type. They may be adjusted
// later if the data is distributed to lanes (as opposed to being owned by
// all lanes uniformly).
SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
insertOp.getOffsets(), [](Attribute attr) { return attr; });
VectorType updatedSourceType = insertOp.getSourceVectorType();
VectorType updatedDestType = insertOp.getDestVectorType();
if (destDistributedDims.size() > 0) {
// Only single dimension distribution is supported.
if (destDistributedDims.size() != 1)
return rewriter.notifyMatchFailure(
warpOp,
"Expecting dest to be distributed in a single dimension.");
int64_t destDistributedDim = destDistributedDims[0];

VectorType srcType = insertOp.getSourceVectorType();
VectorType destType = insertOp.getDestVectorType();
// Currently we require that both source (kD) and dest (nD) vectors are
// distributed. This requires that distributedDim (d) is contained in the
// last k dims of the dest vector (d >= n - k).
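// For example, inserting a 2-D source into a 3-D dest (k = 2, n = 3)
// requires the distributed dim to be 1 or 2.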
int64_t sourceDistributedDim =
destDistributedDim - (destType.getRank() - srcType.getRank());
if (sourceDistributedDim < 0)
return rewriter.notifyMatchFailure(
insertOp,
"distributed dimension must be in the last k (i.e. source "
"rank) dims of dest vector");
int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
// Obtain the source and dest layouts.
auto destLayout =
xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
auto sourceLayout =
xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
if (!destLayout || !sourceLayout ||
destLayout.getEffectiveLaneLayoutAsInt().empty() ||
sourceLayout.getEffectiveLaneLayoutAsInt().empty())
return rewriter.notifyMatchFailure(
warpOp, "the source or dest of insert_strided_slice op lacks "
"distribution layout");
// Because only single dimension distribution is supported, lane layout
// size at the distributed dim must be the subgroup size.
int subgroupSize =
destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
// We require that source and dest lane data are all ones to ensure
// uniform round robin distribution.
auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(
warpOp, "Expecting unit lane data in source and dest layouts");
// Source distributed dim size must be a multiple of the subgroup size.
if (srcDistrDimSize % subgroupSize != 0)
return rewriter.notifyMatchFailure(
warpOp, "Distributed dimension size in source is not a multiple of "
"subgroup size.");
// Offsets in the distributed dimension must be multiples of subgroup
// size.
int64_t destDistrDimOffset =
cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
if (destDistrDimOffset % subgroupSize != 0)
return rewriter.notifyMatchFailure(
warpOp,
"Offset along distributed dimension in dest is not a multiple of "
"subgroup size.");
// Update the source and dest types based on their layouts.
updatedSourceType = getDistVecTypeBasedOnLaneLayout(
sourceLayout, insertOp.getSourceVectorType())
.value();
updatedDestType = getDistVecTypeBasedOnLaneLayout(
destLayout, insertOp.getDestVectorType())
.value();
// Update the distributed offsets to match round robin distribution (i.e.
// each lane owns data at `subgroupSize` stride given unit lane data).
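// For example, with a subgroup size of 16, a sequential dest offset of 16
// maps to a per-lane offset of 1.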
updatedOffsets[destDistributedDim] =
rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
}
// Do the distribution by yielding the source and dest of the insert op
// from the warp op and creating a new insert op outside the warp op.
SmallVector<size_t> newRetIndices;
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
{updatedSourceType, updatedDestType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);

Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
Value dest = newWarpOp.getResult(newRetIndices[1]);
// Create a new insert op outside the warp op.
Value newInsertOp = vector::InsertStridedSliceOp::create(
rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
ArrayAttr::get(rewriter.getContext(), updatedOffsets),
insertOp.getStrides());
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
newInsertOp);
return success();
}
};

/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
/// outside of the warp op.
@@ -1626,9 +1863,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
patterns.add<VectorShapeCastDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
// For the following patterns, we need to override the regular vector
// distribution patterns. Therefore, assign a higher benefit.
patterns
.add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
VectorInsertStridedSliceDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
}

void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(