Skip to content

Commit 6455d47

Browse files
authored
Get rid of ReduceRankOp completely (#152)
1 parent 2a9cc20 commit 6455d47

File tree

6 files changed

+32
-372
lines changed

6 files changed

+32
-372
lines changed

mlir/include/mlir-extensions/dialect/plier_util/PlierUtilOps.td

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,6 @@ def SignCastOp : PlierUtil_Op<"sign_cast", [NoSideEffect]> {
120120
let hasCanonicalizer = 1;
121121
}
122122

123-
def ReduceRankOp
124-
: PlierUtil_Op<"reduce_rank", [ViewLikeOpInterface, NoSideEffect]> {
125-
let arguments = (ins
126-
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$source,
127-
I32ArrayAttr:$mapping);
128-
129-
let results = (outs AnyTypeOf<[AnyMemRef, AnyRankedTensor]>);
130-
let hasFolder = 1;
131-
let hasCanonicalizer = 1;
132-
133-
let builders = [
134-
OpBuilder<(ins "::mlir::Value" : $src,
135-
"::mlir::ArrayRef<int32_t>" : $mapping)>
136-
];
137-
138-
let extraClassDeclaration = [{
139-
::mlir::Value getViewSource() { return source(); }
140-
::llvm::SmallVector<int32_t> getMapping();
141-
}];
142-
}
143-
144123
def ExtractMemrefMetadataOp
145124
: PlierUtil_Op<"extract_memref_metadata", [NoSideEffect]> {
146125
let arguments = (ins AnyMemRef : $source, IndexAttr : $dimIndex);

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 21 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -529,30 +529,6 @@ struct ChangeLayoutSignCast : public mlir::OpRewritePattern<plier::SignCastOp> {
529529
}
530530
};
531531

532-
struct ChangeLayoutReduceRank
533-
: public mlir::OpRewritePattern<plier::ReduceRankOp> {
534-
using OpRewritePattern::OpRewritePattern;
535-
536-
mlir::LogicalResult
537-
matchAndRewrite(plier::ReduceRankOp op,
538-
mlir::PatternRewriter &rewriter) const override {
539-
auto cl = op.source().getDefiningOp<plier::ChangeLayoutOp>();
540-
if (!cl)
541-
return mlir::failure();
542-
543-
auto loc = op.getLoc();
544-
auto newOp = rewriter.createOrFold<plier::ReduceRankOp>(loc, cl.source(),
545-
op.getMapping());
546-
auto oldType = op.getType();
547-
auto newType = newOp.getType();
548-
if (oldType != newType)
549-
newOp = rewriter.createOrFold<plier::ChangeLayoutOp>(loc, oldType, newOp);
550-
551-
rewriter.replaceOp(op, newOp);
552-
return mlir::success();
553-
}
554-
};
555-
556532
struct ChangeLayoutLoad : public mlir::OpRewritePattern<mlir::memref::LoadOp> {
557533
using OpRewritePattern::OpRewritePattern;
558534

@@ -829,11 +805,19 @@ struct ChangeLayout1DReshape
829805
ArrayType strides(srcRank, rewriter.getIndexAttr(1));
830806
auto view = rewriter.createOrFold<mlir::memref::SubViewOp>(
831807
loc, source, offsets, sizes, strides);
832-
auto resType = view.getType().cast<mlir::MemRefType>();
833-
if (resType.getRank() > dstType.getRank()) {
834-
// TODO: Rank-reducing subview
835-
const int32_t mapping[1] = {static_cast<int32_t>(*srcDimIndex)};
836-
view = rewriter.createOrFold<plier::ReduceRankOp>(loc, view, mapping);
808+
auto dstRank = dstType.getRank();
809+
if (srcRank != dstRank) {
810+
assert(dstRank < srcRank);
811+
llvm::SmallVector<mlir::OpFoldResult> newOfsets(srcRank,
812+
rewriter.getIndexAttr(0));
813+
llvm::SmallVector<mlir::OpFoldResult> newStrides(
814+
srcRank, rewriter.getIndexAttr(1));
815+
auto viewType = view.getType().cast<mlir::MemRefType>();
816+
auto reducedType = mlir::memref::SubViewOp::inferRankReducedResultType(
817+
dstRank, viewType, newOfsets, sizes, newStrides)
818+
.cast<mlir::MemRefType>();
819+
view = rewriter.create<mlir::memref::SubViewOp>(
820+
loc, reducedType, view, newOfsets, sizes, newStrides);
837821
}
838822
rewriter.replaceOpWithNewOp<plier::ChangeLayoutOp>(op, dstType, view);
839823
return mlir::success();
@@ -923,14 +907,13 @@ struct ChangeLayoutExpandShape
923907

924908
void ChangeLayoutOp::getCanonicalizationPatterns(
925909
::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
926-
results
927-
.insert<ChangeLayoutIdentity, ChangeLayoutReduceRank, ChangeLayoutDim,
928-
ChangeLayoutExtractMetadata, ChangeLayoutClone,
929-
PropagateCloneType, ChangeLayoutCast, ChangeLayoutSignCast,
930-
ChangeLayoutLoad, ChangeLayoutStore, ChangeLayoutSubview,
931-
ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill, ChangeLayoutIf,
932-
ChangeLayout1DReshape, ChangeLayoutSliceGetItem, ChangeLayoutCopy,
933-
ChangeLayoutExpandShape>(context);
910+
results.insert<
911+
ChangeLayoutIdentity, ChangeLayoutDim, ChangeLayoutExtractMetadata,
912+
ChangeLayoutClone, PropagateCloneType, ChangeLayoutCast,
913+
ChangeLayoutSignCast, ChangeLayoutLoad, ChangeLayoutStore,
914+
ChangeLayoutSubview, ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill,
915+
ChangeLayoutIf, ChangeLayout1DReshape, ChangeLayoutSliceGetItem,
916+
ChangeLayoutCopy, ChangeLayoutExpandShape>(context);
934917
}
935918

936919
static mlir::Value propagateCasts(mlir::Value val, mlir::Type thisType);
@@ -1207,31 +1190,6 @@ struct SignCastMemrefToTensorPropagate
12071190
}
12081191
};
12091192

1210-
struct SignCastReduceRankPropagate
1211-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1212-
using OpRewritePattern::OpRewritePattern;
1213-
1214-
mlir::LogicalResult
1215-
matchAndRewrite(plier::SignCastOp op,
1216-
mlir::PatternRewriter &rewriter) const override {
1217-
auto prevOp = op.value().getDefiningOp<plier::ReduceRankOp>();
1218-
if (!prevOp)
1219-
return mlir::failure();
1220-
1221-
auto src = prevOp.source();
1222-
auto srcType = src.getType().cast<mlir::ShapedType>();
1223-
auto dstType = op.getType().cast<mlir::ShapedType>();
1224-
1225-
auto newSrcType = srcType.clone(dstType.getElementType());
1226-
1227-
auto loc = prevOp->getLoc();
1228-
auto newSrc = rewriter.create<plier::SignCastOp>(loc, newSrcType, src);
1229-
rewriter.replaceOpWithNewOp<plier::ReduceRankOp>(op, newSrc,
1230-
prevOp.getMapping());
1231-
return mlir::success();
1232-
}
1233-
};
1234-
12351193
struct SignCastMemrefSubviewPropagate
12361194
: public mlir::OpRewritePattern<plier::SignCastOp> {
12371195
using OpRewritePattern::OpRewritePattern;
@@ -1272,175 +1230,7 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
12721230
SignCastAllocPropagate<mlir::memref::AllocaOp>,
12731231
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
12741232
SignCastTensorToMemrefPropagate, SignCastMemrefToTensorPropagate,
1275-
SignCastReduceRankPropagate, SignCastMemrefSubviewPropagate>(context);
1276-
}
1277-
1278-
void ReduceRankOp::build(::mlir::OpBuilder &odsBuilder,
1279-
::mlir::OperationState &odsState, ::mlir::Value src,
1280-
::mlir::ArrayRef<int32_t> mapping) {
1281-
assert(src.getType().isa<mlir::ShapedType>());
1282-
auto srcType = src.getType().cast<mlir::ShapedType>();
1283-
assert(srcType.hasRank());
1284-
auto srcRank = static_cast<unsigned>(srcType.getRank());
1285-
assert(!mapping.empty());
1286-
assert(llvm::all_of(mapping, [&](int32_t val) {
1287-
return val >= 0 && val < static_cast<int32_t>(srcRank);
1288-
}));
1289-
auto mapAttr = odsBuilder.getI32ArrayAttr(mapping);
1290-
auto srcShape = srcType.getShape();
1291-
llvm::SmallVector<int64_t> shape(mapping.size());
1292-
for (auto it : llvm::enumerate(mapping))
1293-
shape[it.index()] = srcShape[static_cast<size_t>(it.value())];
1294-
1295-
if (auto tensorType = srcType.dyn_cast<mlir::RankedTensorType>()) {
1296-
auto retType = mlir::RankedTensorType::get(
1297-
shape, tensorType.getElementType(), tensorType.getEncoding());
1298-
build(odsBuilder, odsState, retType, src, mapAttr);
1299-
} else if (auto memrefType = srcType.dyn_cast<mlir::MemRefType>()) {
1300-
auto affineMap = [&]() {
1301-
mlir::AffineMap ret;
1302-
if (!memrefType.getLayout().isIdentity()) {
1303-
auto affineMap = memrefType.getLayout().getAffineMap();
1304-
auto context = odsBuilder.getContext();
1305-
llvm::SmallVector<mlir::AffineExpr> dimReplacements(srcRank);
1306-
llvm::SmallVector<mlir::AffineExpr> symReplacements(srcRank + 1);
1307-
symReplacements[0] = mlir::getAffineSymbolExpr(0, context);
1308-
for (auto i : llvm::seq(0u, srcRank)) {
1309-
auto it = llvm::find(mapping, i);
1310-
if (it != mapping.end()) {
1311-
auto srcIndex = static_cast<unsigned>(it - mapping.begin());
1312-
dimReplacements[i] = mlir::getAffineDimExpr(srcIndex, context);
1313-
symReplacements[i + 1] =
1314-
mlir::getAffineSymbolExpr(srcIndex + 1, context);
1315-
} else {
1316-
dimReplacements[i] = mlir::getAffineConstantExpr(0, context);
1317-
symReplacements[i + 1] = mlir::getAffineConstantExpr(0, context);
1318-
}
1319-
}
1320-
auto dstRank = static_cast<unsigned>(mapping.size());
1321-
auto resMap = affineMap.replaceDimsAndSymbols(
1322-
dimReplacements, symReplacements, dstRank, dstRank + 1);
1323-
ret = mlir::simplifyAffineMap(resMap);
1324-
}
1325-
return ret;
1326-
}();
1327-
1328-
auto retType =
1329-
mlir::MemRefType::get(shape, memrefType.getElementType(), affineMap,
1330-
memrefType.getMemorySpace());
1331-
build(odsBuilder, odsState, retType, src, mapAttr);
1332-
} else {
1333-
llvm_unreachable("ReduceRankOp: Invalid src type");
1334-
}
1335-
}
1336-
1337-
mlir::OpFoldResult
1338-
ReduceRankOp::fold(llvm::ArrayRef<mlir::Attribute> /*operands*/) {
1339-
auto src = source();
1340-
if (src.getType() == getType()) {
1341-
return src;
1342-
}
1343-
return nullptr;
1344-
}
1345-
1346-
llvm::SmallVector<int32_t> ReduceRankOp::getMapping() {
1347-
auto m = mapping();
1348-
llvm::SmallVector<int32_t> ret(m.size());
1349-
llvm::transform(m, ret.begin(), [](mlir::Attribute a) {
1350-
return a.cast<mlir::IntegerAttr>().getValue().getSExtValue();
1351-
});
1352-
return ret;
1353-
}
1354-
1355-
namespace {
1356-
template <typename Op>
1357-
struct ReduceRankDimPropagate : public mlir::OpRewritePattern<Op> {
1358-
using mlir::OpRewritePattern<Op>::OpRewritePattern;
1359-
1360-
mlir::LogicalResult
1361-
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
1362-
auto index = mlir::getConstantIntValue(op.index());
1363-
if (!index)
1364-
return mlir::failure();
1365-
1366-
auto prev = op.source().template getDefiningOp<plier::ReduceRankOp>();
1367-
if (!prev)
1368-
return mlir::failure();
1369-
1370-
auto mappedArg = prev.mapping()[*index]
1371-
.template cast<mlir::IntegerAttr>()
1372-
.getValue()
1373-
.getSExtValue();
1374-
rewriter.replaceOpWithNewOp<Op>(op, prev.source(), mappedArg);
1375-
return mlir::success();
1376-
}
1377-
};
1378-
1379-
static auto mapReduceRankIndices(mlir::OpBuilder &builder, mlir::Location loc,
1380-
plier::ReduceRankOp src,
1381-
mlir::ValueRange srcIndices) {
1382-
auto srcMemref = src.getViewSource();
1383-
auto srcMemrefType = srcMemref.getType().cast<mlir::MemRefType>();
1384-
auto rank = static_cast<unsigned>(srcMemrefType.getRank());
1385-
auto zero = builder.createOrFold<mlir::arith::ConstantIndexOp>(loc, 0);
1386-
auto mapping = src.getMapping();
1387-
llvm::SmallVector<mlir::Value> indices(rank);
1388-
for (auto i : llvm::seq(0u, rank)) {
1389-
auto it = llvm::find(mapping, static_cast<int32_t>(i));
1390-
if (mapping.end() == it) {
1391-
indices[i] = zero;
1392-
} else {
1393-
auto dstIndex = static_cast<size_t>(it - mapping.begin());
1394-
indices[i] = srcIndices[dstIndex];
1395-
}
1396-
}
1397-
return indices;
1398-
}
1399-
1400-
struct ReduceRankLoadPropagate
1401-
: public mlir::OpRewritePattern<mlir::memref::LoadOp> {
1402-
using OpRewritePattern::OpRewritePattern;
1403-
1404-
mlir::LogicalResult
1405-
matchAndRewrite(mlir::memref::LoadOp op,
1406-
mlir::PatternRewriter &rewriter) const override {
1407-
auto src = op.memref().getDefiningOp<plier::ReduceRankOp>();
1408-
if (!src)
1409-
return mlir::failure();
1410-
1411-
auto indices =
1412-
mapReduceRankIndices(rewriter, op.getLoc(), src, op.indices());
1413-
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, src.getViewSource(),
1414-
indices);
1415-
return mlir::success();
1416-
}
1417-
};
1418-
1419-
struct ReduceRankStorePropagate
1420-
: public mlir::OpRewritePattern<mlir::memref::StoreOp> {
1421-
using OpRewritePattern::OpRewritePattern;
1422-
1423-
mlir::LogicalResult
1424-
matchAndRewrite(mlir::memref::StoreOp op,
1425-
mlir::PatternRewriter &rewriter) const override {
1426-
auto src = op.memref().getDefiningOp<plier::ReduceRankOp>();
1427-
if (!src)
1428-
return mlir::failure();
1429-
1430-
auto indices =
1431-
mapReduceRankIndices(rewriter, op.getLoc(), src, op.indices());
1432-
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
1433-
op, op.value(), src.getViewSource(), indices);
1434-
return mlir::success();
1435-
}
1436-
};
1437-
} // namespace
1438-
1439-
void ReduceRankOp::getCanonicalizationPatterns(
1440-
::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
1441-
results.insert<ReduceRankDimPropagate<mlir::tensor::DimOp>,
1442-
ReduceRankDimPropagate<mlir::memref::DimOp>,
1443-
ReduceRankLoadPropagate, ReduceRankStorePropagate>(context);
1233+
SignCastMemrefSubviewPropagate>(context);
14441234
}
14451235

14461236
void ExtractMemrefMetadataOp::build(::mlir::OpBuilder &odsBuilder,
@@ -1502,26 +1292,6 @@ ExtractMemrefMetadataOp::fold(llvm::ArrayRef<mlir::Attribute> /*operands*/) {
15021292
return getResult();
15031293
}
15041294

1505-
if (auto reduceRank = src.getDefiningOp<plier::ReduceRankOp>()) {
1506-
auto newSrc = reduceRank.source();
1507-
if (idx == -1) {
1508-
sourceMutable().assign(newSrc);
1509-
return getResult();
1510-
}
1511-
1512-
auto mapping = reduceRank.getMapping();
1513-
if (static_cast<unsigned>(idx) < mapping.size()) {
1514-
auto newIdx = mapping[static_cast<unsigned>(idx)];
1515-
assert(newIdx >= 0);
1516-
sourceMutable().assign(newSrc);
1517-
auto type = dimIndexAttr().getType();
1518-
dimIndexAttr(mlir::IntegerAttr::get(type, newIdx));
1519-
return getResult();
1520-
}
1521-
1522-
return nullptr;
1523-
}
1524-
15251295
if (auto cast = src.getDefiningOp<mlir::memref::CastOp>()) {
15261296
auto castSrc = cast.source();
15271297
auto castSrcType = castSrc.getType().cast<mlir::ShapedType>();

mlir/lib/transforms/promote_bool_memref.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -202,25 +202,6 @@ class ConvertRetainOp : public mlir::OpConversionPattern<plier::RetainOp> {
202202
return mlir::success();
203203
}
204204
};
205-
206-
class ConvertReduceRankOp
207-
: public mlir::OpConversionPattern<plier::ReduceRankOp> {
208-
public:
209-
using OpConversionPattern::OpConversionPattern;
210-
211-
mlir::LogicalResult
212-
matchAndRewrite(plier::ReduceRankOp op, plier::ReduceRankOp::Adaptor adaptor,
213-
mlir::ConversionPatternRewriter &rewriter) const override {
214-
auto *converter = getTypeConverter();
215-
auto resType = converter->convertType(op.getType())
216-
.dyn_cast_or_null<mlir::MemRefType>();
217-
if (!resType)
218-
return mlir::failure();
219-
rewriter.replaceOpWithNewOp<plier::ReduceRankOp>(
220-
op, resType, adaptor.source(), adaptor.mapping());
221-
return mlir::success();
222-
}
223-
};
224205
} // namespace
225206

226207
void plier::populatePromoteBoolMemrefConversionRewritesAndTarget(
@@ -238,12 +219,11 @@ void plier::populatePromoteBoolMemrefConversionRewritesAndTarget(
238219
});
239220

240221
target.addDynamicallyLegalDialect<mlir::memref::MemRefDialect>(&checkOp);
241-
target.addDynamicallyLegalOp<plier::RetainOp, plier::ReduceRankOp>(&checkOp);
222+
target.addDynamicallyLegalOp<plier::RetainOp>(&checkOp);
242223

243224
patterns.insert<ConvertDimOp, ConvertLoadOp, ConvertStoreOp, ConvertAllocOp,
244225
ConvertAllocaOp, ConvertDeallocOp, ConvertCastOp,
245-
ConvertSubviewOp, ConvertRetainOp, ConvertReduceRankOp>(
246-
typeConverter, context);
226+
ConvertSubviewOp, ConvertRetainOp>(typeConverter, context);
247227
}
248228

249229
namespace {

0 commit comments

Comments
 (0)