 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };

+// This struct contains information about extract_slice dims.
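+// For a dim with a partial slice, `offset` and `sliceSize` describe the
+// extracted window, and `outputSize` is the full size of the corresponding
+// source dim.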
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex = 0;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+// Return a map of the dims that have non-full slices on them, so that other
+// operands can use this information. Also return a bool indicating whether a
+// reduction dim has a non-full slice, as that can be used to fold the
+// original extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next, check whether any dim with non-full slice info is used in a
+  // non-AffineDimExpr indexing-map result; if so, bail out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (&operand == sliceOperand) {
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          return status.wasInterrupted();
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
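+/// Push an input tensor.extract_slice down through a linalg.generic.
+/// Conceptually (illustrative IR only; the types, values, and attributes here
+/// are hypothetical, not taken from this patch's tests):
+///
+///   %s = tensor.extract_slice %src[2, 0] [6, 16] [1, 1]
+///       : tensor<8x16xf32> to tensor<6x16xf32>
+///   %r = linalg.generic ... ins(%s, %b : ...) outs(%init : tensor<6x16xf32>)
+///
+/// becomes
+///
+///   %pb = tensor.pad %b low[2, 0] high[0, 0] { ub.poison } ...
+///   %pi = tensor.pad %init low[2, 0] high[0, 0] { ub.poison } ...
+///   %pr = linalg.generic ... ins(%src, %pb : ...) outs(%pi : tensor<8x16xf32>)
+///   %r = tensor.extract_slice %pr[2, 0] [6, 16] [1, 1]
+///       : tensor<8x16xf32> to tensor<6x16xf32>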
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the sliced operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned operandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expected a valid ExtractSliceOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  // Check whether we can propagate this extract_slice through the generic op
+  // and, if so, collect the dims that have non-full slices on them.
+  auto maybeNonZeroSliceDimMap =
+      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Helper to compute the difference of two OpFoldResults as a folded affine
+  // apply; used below to derive the high padding for each sliced dim.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
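+  // For example, a dim of size 16 sliced at offset 4 with size 6 gets
+  // lowPad = offset = 4 and highPad = (16 - 4) - 6 = 6, so that
+  // 4 + 6 + 6 = 16 restores the full source size.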
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
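+  // Replace the sliced input with its full source when legal; pad every other
+  // input with ub.poison so all operands keep indexing the original
+  // (un-sliced) iteration space.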
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == operandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      outputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
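+  // If the original outs operand is never read, a fresh tensor.empty of the
+  // padded shape is enough; otherwise pad the existing init with poison.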
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
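+  // Re-extract the original slice from the padded result so existing users
+  // keep seeing a value of the original type.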
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace

 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
                   PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
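+
+// Illustrative usage sketch (not part of this patch): one plausible way to
+// run the sinking patterns from a pass, assuming an `Operation *op` and an
+// `MLIRContext *ctx` are in scope.
+//
+//   RewritePatternSet patterns(ctx);
+//   linalg::populateExtractSliceSinkingPatterns(
+//       patterns, [](OpOperand *) { return true; });
+//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
+//     signalPassFailure();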