242 changes: 213 additions & 29 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <utility>

using namespace mlir;
@@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
}
};

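/// Return the type `source` is expected to have after distribution across
/// the lanes of `warpOp`. If `source` is yielded from the warp op, the
/// matching warp result type is reused; otherwise the distributed type is
/// computed from the distribution map and the warp size, falling back to the
/// original source type if no distributed type can be computed.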
static VectorType
tryFindDistributedType(TypedValue<VectorType> source,
WarpExecuteOnLane0Op warpOp,
const DistributionMapFn &distributionMapFn) {
VectorType distributedType = source.getType();
// Check if the source is yielded from the warp op.
gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
return operand.get() == source;
});

if (it != yieldOp->getOpOperands().end()) {
// If the source is yielded from the warp op, we can use the matching
// warp result type as the distributed source type.
distributedType =
cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
} else {
// If the source is not yielded from the warp op, we need to compute
// the distributed source type based on the distribution map and the
// warp size.
AffineMap map = distributionMapFn(source);
VectorType computed =
getDistributedType(source.getType(), map, warpOp.getWarpSize());
if (!computed)
return source.getType();
distributedType = computed;
}
return distributedType;
}

struct WarpOpBroadcast : public WarpDistributionPattern {
using Base::Base;
WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
@@ -953,68 +991,107 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
auto destVecType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
Value broadcastSrc = broadcastOp.getSource();
Type broadcastSrcType = broadcastSrc.getType();
Type srcDistributedType = broadcastSrc.getType();

if (isa<VectorType>(srcDistributedType))
srcDistributedType =
tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
warpOp, distributionMapFn);

// Check that the broadcast actually spans a set of values uniformly across
// all threads. In other words, check that each thread can reconstruct
// its own broadcast.
// For that we simply check that the broadcast we want to build makes sense.
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
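// Rebuild the broadcast outside the warp op from the (possibly distributed)
// source yielded by the new warp op, and forward it to all users of the
// original warp result.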
Value broadcasted = vector::BroadcastOp::create(
rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
broadcasted);
return success();
}

private:
DistributionMapFn distributionMapFn;
};

/// Pattern to move a shape cast out of the warp op. A shape cast is
/// essentially a no-op for warp distribution; only the distributed shape
/// needs to be reconciled.
struct WarpOpShapeCast : public WarpDistributionPattern {
using Base::Base;

WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();

auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

unsigned int operandNumber = operand->getOperandNumber();
auto castDistributedType =
VectorType sourceType = oldCastOp.getSourceVectorType();
VectorType distributedResultType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
VectorType castOriginalType = oldCastOp.getSourceVectorType();
VectorType castResultType = castDistributedType;

// We expect the distributed type to have a smaller rank than the original
// type. Prepend with size-one dimensions to make them the same.
unsigned castDistributedRank = castDistributedType.getRank();
unsigned castOriginalRank = castOriginalType.getRank();
if (castDistributedRank < castOriginalRank) {
SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
llvm::append_range(shape, castDistributedType.getShape());
castDistributedType =
VectorType::get(shape, castDistributedType.getElementType());
VectorType distributedSourceType = sourceType;
bool isResultDistributed = distributedResultType.getNumElements() <
oldCastOp.getResultVectorType().getNumElements();

// If the result is not distributed, the distributed source type is the same
// as the source type. If the result is distributed, the distributed source
// type is computed according to the following rules:
// 1. If the source is yielded from the warp op, use the matching warp
// result type as the distributed source type.
// 2. If the source is not yielded from the warp op, compute the distributed
// source type based on the distribution map and the warp size.
if (isResultDistributed) {
// Check if the source is yielded from the warp op.
gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
auto *it =
llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
return operand.get() == oldCastOp.getSource();
});

if (it != yieldOp->getOpOperands().end()) {
// If the source is yielded from the warp op, we can use the matching
// warp result type as the distributed source type.
distributedSourceType =
cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
} else {
// If the source is not yielded from the warp op, we need to compute
// the distributed source type based on the distribution map and the
// warp size.
AffineMap map = distributionMapFn(oldCastOp.getSource());
distributedSourceType =
getDistributedType(sourceType, map, warpOp.getWarpSize());
if (!distributedSourceType)
return rewriter.notifyMatchFailure(
oldCastOp,
"cannot compute distributed source type for shape cast");
}
}

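// Yield the (possibly distributed) source from the warp op and rebuild the
// shape cast outside of it with the distributed result type.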
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newCast = vector::ShapeCastOp::create(
rewriter, oldCastOp.getLoc(), castResultType,
rewriter, oldCastOp.getLoc(), distributedResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
return success();
}

private:
DistributionMapFn distributionMapFn;
};

/// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1995,6 +2072,114 @@ struct WarpOpReduction : public WarpDistributionPattern {
DistributedReductionFn distributedReductionFn;
};

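/// Pattern to distribute `vector.multi_reduction` ops feeding a warp op
/// yield. Only 2-D sources with a single reduction dimension are supported.
/// Column reductions (reduction dim 0) distribute the source columns across
/// lanes and reduce each per-lane column with a `vector.reduction`; row
/// reductions (reduction dim 1) are rewritten in place inside the warp
/// region as one `vector.reduction` per row.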
struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
VectorMultiDimReductionDistribution(MLIRContext *context,
PatternBenefit benefit = 1)
: WarpDistributionPattern(context, benefit) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
if (!yieldOperand)
return failure();
auto reductionOp =
cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
unsigned operandNumber = yieldOperand->getOperandNumber();
VectorType sourceType = reductionOp.getSourceVectorType();
VectorType distributedResultType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
Type elementType = distributedResultType.getElementType();
// Only 2D vectors are supported.
if (sourceType.getRank() != 2)
return rewriter.notifyMatchFailure(warpOp,
"Only 2D reductions are supported.");
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
// Only a single reduction dimension is supported.
if (reductionDims.size() != 1)
return rewriter.notifyMatchFailure(
warpOp, "Only 1 reduction dimension is supported.");

// Column reduction.
if (reductionDims[0] == 0) {
if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
return rewriter.notifyMatchFailure(
warpOp, "Source vector dimension must be divisible by warp size.");
// Distribute the source columns across lanes and yield the distributed
// source vector and the accumulator from the warp op.
SmallVector<int64_t> shape(sourceType.getShape());
shape[1] = shape[1] / warpOp.getWarpSize();
auto sourceDistributedType = VectorType::get(shape, elementType);
SmallVector<size_t> newRetIndices;
auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
{sourceDistributedType, distributedResultType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
// Create a zero-initialized vector to hold the per-column reduction
// results.
auto zeroAttr =
rewriter.getZeroAttr(distributedResultType.getElementType());
Value result = arith::ConstantOp::create(
rewriter, reductionOp->getLoc(), distributedResultType,
DenseElementsAttr::get(distributedResultType, zeroAttr));
int nCols = sourceDistributedType.getShape()[1];
Value source = newWarpOp.getResult(newRetIndices[0]);
Value acc = newWarpOp.getResult(newRetIndices[1]);
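// Each lane now holds nCols columns of the distributed source. For example
// (a hypothetical vector<32x64xf32> source with warp size 32), each lane
// reduces its own vector<32x2xf32> slice, i.e. nCols == 2. Reduce each
// local column with its accumulator element and insert the scalar result.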
for (int i = 0; i < nCols; ++i) {
Value col = vector::ExtractStridedSliceOp::create(
rewriter, reductionOp.getLoc(), source, {0, i},
{sourceDistributedType.getShape()[0], 1}, {1, 1});
col = vector::ShapeCastOp::create(
rewriter, reductionOp.getLoc(),
VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
col);
Value accCol =
vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
Value colReduce = vector::ReductionOp::create(
rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
// Insert the reduced column into the result.
result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
colReduce, result, i);
}
// Replace the warp op result with the newly built result vector.
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
return success();
}
// Row reduction.
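// Row reductions are rewritten in place inside the warp region as one
// vector.reduction per row; the result replaces the original
// multi_reduction result.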
// Create a zero-initialized vector to hold the per-row reduction results.
rewriter.setInsertionPointAfter(reductionOp);
auto zeroAttr =
rewriter.getZeroAttr(distributedResultType.getElementType());
Value result = arith::ConstantOp::create(
rewriter, reductionOp->getLoc(), distributedResultType,
DenseElementsAttr::get(distributedResultType, zeroAttr));
int nRows = sourceType.getShape()[0];
// For each row, do a vector reduction.
for (int i = 0; i < nRows; ++i) {
Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
reductionOp.getSource(), i);
Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
reductionOp.getAcc(), i);
Value rowReduce = vector::ReductionOp::create(
rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
rowReduce, result, i);
}
// Replace the original multi_reduction result with the newly built vector.
rewriter.replaceAllUsesWith(reductionOp.getResult(), result);

return success();
}
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
Expand All @@ -2015,16 +2200,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
patterns
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
patterns.getContext(), benefit);
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpExtract,
WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
patterns.add<WarpOpScfForOp, WarpOpShapeCast, WarpOpBroadcast>(
patterns.getContext(), distributionMapFn, benefit);
}

void mlir::vector::populateDistributeReduction(