7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                     const ControlBlockPackMatmulFn &controlFn);

/// Adds patterns that reduce the rank of named contraction ops with unit
/// dimensions in their operand(s) by converting them to a sequence of
/// `collapse_shape`, `<corresponding linalg named op>`, and `expand_shape`
/// (the reshapes are emitted only when operating on tensors). For example, a
/// `linalg.batch_matmul` with unit batch size converts to `linalg.matmul`,
/// and a `linalg.matvec` with a unit spatial dimension in its lhs converts to
/// a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir
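
As an illustration, the rewrite this enables looks roughly as follows on tensors (a sketch with made-up shapes; the exact `tensor.expand_shape` syntax, including the `output_shape` operand shown here, varies across MLIR versions):

    %0 = linalg.batch_matmul
           ins(%lhs, %rhs : tensor<1x4x8xf32>, tensor<1x8x16xf32>)
           outs(%init : tensor<1x4x16xf32>) -> tensor<1x4x16xf32>

becomes

    %c_lhs = tensor.collapse_shape %lhs [[0, 1], [2]]
        : tensor<1x4x8xf32> into tensor<4x8xf32>
    %c_rhs = tensor.collapse_shape %rhs [[0, 1], [2]]
        : tensor<1x8x16xf32> into tensor<8x16xf32>
    %c_init = tensor.collapse_shape %init [[0, 1], [2]]
        : tensor<1x4x16xf32> into tensor<4x16xf32>
    %mm = linalg.matmul ins(%c_lhs, %c_rhs : tensor<4x8xf32>, tensor<8x16xf32>)
        outs(%c_init : tensor<4x16xf32>) -> tensor<4x16xf32>
    %1 = tensor.expand_shape %mm [[0, 1], [2]] output_shape [1, 4, 16]
        : tensor<4x16xf32> into tensor<1x4x16xf32>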

213 changes: 213 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,217 @@ struct LinalgFoldUnitExtentDimsPass
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};

} // namespace

namespace {

/// Returns a reassociation that groups the two innermost dims of a shape of
/// the given rank together (no grouping is needed for rank <= 1).
static SmallVector<ReassociationIndices>
getReassociationsForTrailingDims(int64_t rank) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
  if (rank > 1) {
    // Group the last two dims together; all leading dims map one-to-one.
    reassociation[rank - 2] = ReassociationIndices{rank - 2, rank - 1};
    for (int64_t i = 0; i < rank - 2; i++)
      reassociation[i] = {i};
  }
  return reassociation;
}

/// Returns a reassociation that groups the two outermost dims of a shape of
/// the given rank together (no grouping is needed for rank <= 1).
static SmallVector<ReassociationIndices>
getReassociationsForLeadingDims(int64_t rank) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {});
  if (rank > 1) {
    // Group the first two dims together; all trailing dims shift down by one.
    reassociation[0] = ReassociationIndices{0, 1};
    for (int64_t i = 1; i < rank - 1; i++)
      reassociation[i] = {i + 1};
  }
  return reassociation;
}
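
// For example:
//   getReassociationsForTrailingDims(2) -> [{0, 1}]       (2-D -> 1-D)
//   getReassociationsForTrailingDims(3) -> [{0}, {1, 2}]  (3-D -> 2-D)
//   getReassociationsForLeadingDims(3)  -> [{0, 1}, {2}]  (3-D -> 2-D)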

/// Collapses a unit leading dim of `val`, e.g. tensor<1x4x8xf32> ->
/// tensor<4x8xf32>.
static Value collapseLeadingSingletonDim(PatternRewriter &rewriter, Value val) {
  auto valType = cast<ShapedType>(val.getType());
  return collapseValue(
      rewriter, val.getLoc(), val, valType.getShape().drop_front(1),
      getReassociationsForLeadingDims(valType.getRank()),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Collapses a unit trailing dim of `val`, e.g. tensor<4x8x1xf32> ->
/// tensor<4x8xf32>.
static Value collapseTrailingSingletonDim(PatternRewriter &rewriter,
                                          Value val) {
  auto valType = cast<ShapedType>(val.getType());
  return collapseValue(
      rewriter, val.getLoc(), val, valType.getShape().drop_back(1),
      getReassociationsForTrailingDims(valType.getRank()),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base pattern for rank-reducing named contraction ops: collapses the unit
/// dims on the operands, creates the lower-rank `ToOpTy`, and (on tensors)
/// expands the result back to the original type.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];

    if (!checkTypes(lhs, rhs, init))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    auto collapsedOperands = collapseOperands(rewriter, lhs, rhs, init);
    auto collapsedLhs = collapsedOperands[0];
    auto collapsedRhs = collapsedOperands[1];
    auto collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    // Preserve user attributes, but drop the memoized indexing maps, which
    // are no longer valid for the rank-reduced op.
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty())
      rewriter.replaceOp(contractionOp, collapsedOp);
    else
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType())));

    return success();
  }

  /// Returns true when the operands carry the unit dim this pattern drops.
  virtual bool checkTypes(Value lhs, Value rhs, Value init) const = 0;
  /// Collapses the unit dim on the affected operands.
  virtual SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter,
                                                 Value lhs, Value rhs,
                                                 Value init) const = 0;
  /// Expands the rank-reduced result back to `expandedType`.
  virtual Value expandResult(PatternRewriter &rewriter, Value result,
                             RankedTensorType expandedType) const = 0;
};

/// Drops a unit batch dim, e.g. `linalg.batch_matmul` with batch size 1 ->
/// `linalg.matmul`.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  bool checkTypes(Value lhs, Value rhs, Value init) const override {
    auto lhsType = cast<ShapedType>(lhs.getType());
    auto rhsType = cast<ShapedType>(rhs.getType());
    auto initType = cast<ShapedType>(init.getType());
    return lhsType.getShape()[0] == 1 && rhsType.getShape()[0] == 1 &&
           initType.getShape()[0] == 1;
  }

  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                         Value rhs, Value init) const override {
    auto collapsedLhs = collapseLeadingSingletonDim(rewriter, lhs);
    auto collapsedRhs = collapseLeadingSingletonDim(rewriter, rhs);
    auto collapsedInit = collapseLeadingSingletonDim(rewriter, init);
    return SmallVector<Value, 3>{collapsedLhs, collapsedRhs, collapsedInit};
  }

  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType) const override {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationsForLeadingDims(expandedType.getRank()));
  }
};

/// Drops a unit spatial (M or N) dim from a possibly transposed matmul-like
/// op, e.g. `linalg.matmul` with a unit M dim -> `linalg.vecmat`.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Whether the source op transposes an input; a transpose moves the
  /// candidate unit dim to the opposite end of that operand's shape.
  static constexpr bool isTranspose =
      std::is_same<FromOpTy, MatmulTransposeAOp>::value ||
      std::is_same<FromOpTy, MatmulTransposeBOp>::value;

  /// Whether the unit dim being dropped belongs to the lhs (and init);
  /// otherwise it belongs to the rhs (and init).
  static constexpr bool reduceLeft =
      (std::is_same<FromOpTy, MatmulOp>::value &&
       std::is_same<ToOpTy, VecmatOp>::value) ||
      (std::is_same<FromOpTy, MatmulTransposeAOp>::value &&
       std::is_same<ToOpTy, VecmatOp>::value) ||
      (std::is_same<FromOpTy, MatvecOp>::value &&
       std::is_same<ToOpTy, DotOp>::value);

  bool checkTypes(Value lhs, Value rhs, Value init) const override {
    auto lhsType = cast<ShapedType>(lhs.getType());
    auto rhsType = cast<ShapedType>(rhs.getType());
    auto initType = cast<ShapedType>(init.getType());
    // A transpose shifts the candidate unit dim one position inward on the
    // transposed operand; the init is never transposed.
    constexpr int offset = static_cast<int>(isTranspose);
    if constexpr (reduceLeft)
      return lhsType.getShape().begin()[offset] == 1 &&
             initType.getShape().front() == 1;
    else
      return rhsType.getShape().rbegin()[offset] == 1 &&
             initType.getShape().back() == 1;
  }

  SmallVector<Value, 3> collapseOperands(PatternRewriter &rewriter, Value lhs,
                                         Value rhs, Value init) const override {
    if constexpr (reduceLeft) {
      // A transposed lhs carries its unit dim at the trailing position; the
      // init is never transposed, so its unit dim is always leading here.
      if constexpr (isTranspose)
        lhs = collapseTrailingSingletonDim(rewriter, lhs);
      else
        lhs = collapseLeadingSingletonDim(rewriter, lhs);
      init = collapseLeadingSingletonDim(rewriter, init);
    } else {
      // A transposed rhs carries its unit dim at the leading position; the
      // init is never transposed, so its unit dim is always trailing here.
      if constexpr (isTranspose)
        rhs = collapseLeadingSingletonDim(rewriter, rhs);
      else
        rhs = collapseTrailingSingletonDim(rewriter, rhs);
      init = collapseTrailingSingletonDim(rewriter, init);
    }
    return SmallVector<Value, 3>{lhs, rhs, init};
  }

  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType) const override {
    if constexpr (reduceLeft)
      return rewriter.create<tensor::ExpandShapeOp>(
          result.getLoc(), expandedType, result,
          getReassociationsForLeadingDims(expandedType.getRank()));
    else
      return rewriter.create<tensor::ExpandShapeOp>(
          result.getLoc(), expandedType, result,
          getReassociationsForTrailingDims(expandedType.getRank()));
  }
};
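
// For example, a linalg.matmul whose lhs is tensor<1x8xf32> matches
// RankReduceMatmul<MatmulOp, VecmatOp> (reduceLeft) and rewrites to
// linalg.vecmat, while one whose rhs is tensor<8x1xf32> matches
// RankReduceMatmul<MatmulOp, MatvecOp> and rewrites to linalg.matvec.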

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unit batch dims.
  patterns.add<RankReduceBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
      context);
  patterns.add<RankReduceBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
      context);
  patterns.add<RankReduceBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceBatched<BatchVecmatOp, VecmatOp>>(context);
  // Unit spatial dims.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
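
A minimal sketch of how a pass might drive these patterns, mirroring how LinalgFoldUnitExtentDimsPass applies its own pattern set above (the pass class name here is hypothetical, not part of this patch):

    // Hypothetical pass body: collect the rank-reducing patterns and apply
    // them with the standard greedy rewrite driver.
    void RankReduceContractionOpsPass::runOnOperation() {
      RewritePatternSet patterns(&getContext());
      linalg::populateContractionOpRankReducingPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns))))
        signalPassFailure();
    }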