7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);

/// Adds patterns that reduce the rank of named contraction ops with unit
/// dimensions in their operand(s) by converting them to a sequence of
/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape`
/// (the reshapes are emitted only when operating on tensors). For example,
/// a `linalg.batch_matmul` with unit batch size converts to `linalg.matmul`,
/// and a `linalg.matvec` with a unit spatial dim in its lhs converts to a
/// `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

241 changes: 241 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,245 @@ struct LinalgFoldUnitExtentDimsPass
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};

} // namespace

namespace {

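/// Returns the reassociation that collapses the unit dimension at `pos` into
/// its neighbor, e.g. rank = 4, pos = 1 yields [[0], [1, 2], [3]], and
/// rank = 3, pos = 2 (the last dim) yields [[0], [1, 2]].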
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
auto lastDim = pos == rank - 1;
if (rank > 2) {
for (int64_t i = 0; i < rank - 1; i++) {
if (i == pos || (lastDim && i == pos - 1))
reassociation[i] = ReassociationIndices{i, i + 1};
else if (i < pos)
reassociation[i] = ReassociationIndices{i};
else
reassociation[i] = ReassociationIndices{i + 1};
}
}
return reassociation;
}

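/// Drops the unit dimension of `val` at `pos` with a rank-reducing
/// `collapse_shape`. A negative `pos` means the operand has no unit dim to
/// drop; it is returned unchanged.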
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
int64_t pos) {
if (pos < 0)
return val;
auto valType = cast<ShapedType>(val.getType());
SmallVector<int64_t> collapsedShape(valType.getShape());
collapsedShape.erase(collapsedShape.begin() + pos);
return collapseValue(
rewriter, val.getLoc(), val, collapsedShape,
getReassociationForReshapeAtDim(valType.getRank(), pos),
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

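/// Base rewrite: collapses the unit dim reported by `getOperandUnitDims` on
/// each operand, creates `ToOpTy` over the collapsed operands, and, when on
/// tensors, expands the result back to the original result type.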
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
using OpRewritePattern<FromOpTy>::OpRewritePattern;

SmallVector<Value, 3>
collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
ArrayRef<int64_t> operandCollapseDims) const {
assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
"expected 3 operands and dims");
return llvm::to_vector(llvm::map_range(
llvm::zip(operands, operandCollapseDims), [&](auto pair) {
return collapseSingletonDimAt(rewriter, std::get<0>(pair),
std::get<1>(pair));
}));
}

Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType, int64_t dim) const {
return rewriter.create<tensor::ExpandShapeOp>(
result.getLoc(), expandedType, result,
getReassociationForReshapeAtDim(expandedType.getRank(), dim));
}

LogicalResult matchAndRewrite(FromOpTy contractionOp,
PatternRewriter &rewriter) const override {
auto loc = contractionOp.getLoc();
auto inputs = contractionOp.getDpsInputs();
auto inits = contractionOp.getDpsInits();
if (inputs.size() != 2 || inits.size() != 1)
return rewriter.notifyMatchFailure(contractionOp,
"expected 2 inputs and 1 init");
auto lhs = inputs[0];
auto rhs = inputs[1];
auto init = inits[0];
SmallVector<Value> operands{lhs, rhs, init};

auto maybeContractionDims = inferContractionDims(contractionOp);
if (failed(maybeContractionDims))
return rewriter.notifyMatchFailure(contractionOp,
"could not infer contraction dims");

SmallVector<int64_t> operandUnitDims;
if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
return rewriter.notifyMatchFailure(contractionOp,
"no reducable dims found");

auto collapsedOperands =
collapseOperands(rewriter, operands, operandUnitDims);
auto collapsedLhs = collapsedOperands[0];
auto collapsedRhs = collapsedOperands[1];
auto collapsedInit = collapsedOperands[2];
SmallVector<Type, 1> collapsedResultTy;
if (isa<RankedTensorType>(collapsedInit.getType()))
collapsedResultTy.push_back(collapsedInit.getType());
auto collapsedOp = rewriter.create<ToOpTy>(
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}

auto results = contractionOp.getResults();
assert(results.size() < 2 && "expected at most one result");
if (results.empty())
rewriter.replaceOp(contractionOp, collapsedOp);
else
rewriter.replaceOp(
contractionOp,
expandResult(rewriter, collapsedOp.getResultTensors()[0],
cast<RankedTensorType>(results[0].getType()),
operandUnitDims[2]));

return success();
}

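/// Populates `operandUnitDims` with, for each of lhs, rhs, and init, the
/// position of the unit dimension to collapse, or -1 if that operand has no
/// dimension to drop.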
virtual LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

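/// Collapses a unit batch dimension, e.g. `linalg.batch_matmul` with batch
/// size 1 becomes `linalg.matmul`.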
template <typename FromOpTy, typename ToOpTy>
struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

LogicalResult getOperandUnitDims(
LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
auto inputs = op.getDpsInputs();
auto inits = op.getDpsInits();
if (inputs.size() != 2 || inits.size() != 1)
return failure();

auto maybeContractionDims = inferContractionDims(op);
if (failed(maybeContractionDims))
return failure();
auto contractionDims = maybeContractionDims.value();

if (contractionDims.batch.size() != 1)
return failure();
auto batchDim = contractionDims.batch[0];
SmallVector<std::pair<Value, unsigned>, 2> bOperands;
op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] != 1;
}))
return failure();

operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
std::get<1>(bOperands[1]),
std::get<1>(bOperands[2])};
return success();
}
};

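/// Collapses a unit m or n dimension. `reduceLeft` is true when the unit dim
/// sits in the lhs (and init), e.g. `linalg.matmul` with unit m becomes
/// `linalg.vecmat`; otherwise it sits in the rhs (and init), e.g.
/// `linalg.matmul` with unit n becomes `linalg.matvec`.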
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

static bool constexpr reduceLeft =
(std::is_same<FromOpTy, BatchMatmulOp>::value &&
std::is_same<ToOpTy, BatchVecmatOp>::value) ||
(std::is_same<FromOpTy, BatchMatmulTransposeAOp>::value &&
std::is_same<ToOpTy, BatchVecmatOp>::value) ||
(std::is_same<FromOpTy, MatmulOp>::value &&
std::is_same<ToOpTy, VecmatOp>::value) ||
(std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
std::is_same<ToOpTy, VecmatOp>::value) ||
(std::is_same<FromOpTy, MatvecOp>::value &&
std::is_same<ToOpTy, DotOp>::value);

LogicalResult getOperandUnitDims(
LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
auto maybeContractionDims = inferContractionDims(op);
if (failed(maybeContractionDims))
return failure();
auto contractionDims = maybeContractionDims.value();

if constexpr (reduceLeft) {
auto m = contractionDims.m[0];
SmallVector<std::pair<Value, unsigned>, 2> mOperands;
op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
if (mOperands.size() != 2)
return failure();
if (llvm::all_of(mOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
operandUnitDims = SmallVector<int64_t>{
std::get<1>(mOperands[0]), -1, std::get<1>(mOperands[1])};
return success();
}
} else {
auto n = contractionDims.n[0];
SmallVector<std::pair<Value, unsigned>, 2> nOperands;
op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
if (nOperands.size() != 2)
return failure();
if (llvm::all_of(nOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
operandUnitDims = SmallVector<int64_t>{
-1, std::get<1>(nOperands[0]), std::get<1>(nOperands[1])};
return success();
}
}
return failure();
}
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
context);
patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
context);
patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);

// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
context);
patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
context);

// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
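
// A minimal usage sketch (an assumption, not part of this patch): drive the
// patterns greedily from a pass body with an `Operation *op` in scope,
// mirroring LinalgFoldUnitExtentDimsPass above:
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateContractionOpRankReducingPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));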