66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
910#include " mlir/Dialect/Linalg/IR/Linalg.h"
1011#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
1112#include " mlir/Dialect/Linalg/Utils/Utils.h"
1213#include " mlir/Dialect/Tensor/IR/Tensor.h"
14+ #include " mlir/Dialect/UB/IR/UBOps.h"
1315#include " mlir/Dialect/Utils/IndexingUtils.h"
1416#include " mlir/IR/Dominance.h"
1517#include " llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
12361238 ControlPropagationFn controlFn;
12371239};
12381240
1241+ // This struct contains infomation about extract_slice dims.
1242+ struct SliceDimInfo {
1243+ OpFoldResult offset;
1244+ OpFoldResult sliceSize;
1245+ OpFoldResult outputSize;
1246+ };
1247+
1248+ // / Return the first input extract slice operand, if present, for the current
1249+ // / generic op.
1250+ static FailureOr<OpOperand *> getSliceOperand (GenericOp genericOp) {
1251+ OpOperand *sliceOperand = nullptr ;
1252+ for (auto operand : genericOp.getDpsInputOperands ()) {
1253+ auto extractOp = operand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1254+ if (!extractOp)
1255+ continue ;
1256+ sliceOperand = operand;
1257+ break ;
1258+ }
1259+ if (!sliceOperand) {
1260+ return failure ();
1261+ }
1262+ return sliceOperand;
1263+ }
1264+
1265+ // Return a map of dims that have partial slices on them so that other operands
1266+ // can use this information. Also return a bool mentioning if a reduction dim
1267+ // has a non full slice as that can be used to fold the original extract slice.
1268+ static FailureOr<llvm::DenseMap<int64_t , SliceDimInfo>>
1269+ getPartialSliceDimInfo (GenericOp genericOp, OpOperand *sliceOperand) {
1270+ tensor::ExtractSliceOp producerSliceOp =
1271+ sliceOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1272+ assert (producerSliceOp && " expect a valid ExtractSliceOp" );
1273+ llvm::DenseMap<int64_t , SliceDimInfo> partialSliceDimMap;
1274+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets ();
1275+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes ();
1276+
1277+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult (
1278+ genericOp.getContext (), producerSliceOp.getSourceType ().getShape ());
1279+
1280+ for (auto [idx, expr] : llvm::enumerate (
1281+ genericOp.getMatchingIndexingMap (sliceOperand).getResults ())) {
1282+ // If we have a full slice in a dimension then we dont need to add it to
1283+ // the partial slice map.
1284+ if (isConstantIntValue (offsets[idx], 0 ) &&
1285+ isEqualConstantIntOrValue (sizes[idx], shape[idx])) {
1286+ continue ;
1287+ }
1288+ // We only support partial slices of AffineDimExprs so bail-out if thats not
1289+ // the case.
1290+ if (!isa<AffineDimExpr>(expr)) {
1291+ return failure ();
1292+ }
1293+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1294+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition ();
1295+ partialSliceDimMap[dimPos] = sliceDimInfo;
1296+ }
1297+ // Next check if the dims with partial slice info are used in non
1298+ // AffineDimExpr in other operands and if they are then bail-out.
1299+ for (OpOperand &operand : genericOp->getOpOperands ()) {
1300+ if (operand == *sliceOperand) {
1301+ continue ;
1302+ }
1303+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (&operand);
1304+ if (llvm::any_of (IndexingMap.getResults (), [&](AffineExpr expr) {
1305+ if (isa<AffineDimExpr>(expr)) {
1306+ return false ;
1307+ }
1308+ WalkResult status = expr.walk ([&](AffineExpr expr) {
1309+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1310+ if (partialSliceDimMap.contains (dimExpr.getPosition ())) {
1311+ return WalkResult::interrupt ();
1312+ }
1313+ }
1314+ return WalkResult::advance ();
1315+ });
1316+ if (status.wasInterrupted ()) {
1317+ return true ;
1318+ }
1319+ return false ;
1320+ })) {
1321+ return failure ();
1322+ }
1323+ }
1324+ return partialSliceDimMap;
1325+ }
1326+
1327+ static FailureOr<std::tuple<GenericOp, Value>>
1328+ pushDownExtractSliceOpThroughGenericOp (RewriterBase &rewriter,
1329+ GenericOp genericOp,
1330+ ControlPropagationFn controlFn) {
1331+ if (genericOp.getNumResults () != 1 )
1332+ return rewriter.notifyMatchFailure (
1333+ genericOp, " propagation through multi-result generic is unsupported." );
1334+ if (hasGatherSemantics (genericOp))
1335+ return rewriter.notifyMatchFailure (
1336+ genericOp,
1337+ " propagation through generic with gather semantics is unsupported." );
1338+ // Collect the sliced operand, if present.
1339+ auto maybeSliceOperand = getSliceOperand (genericOp);
1340+ if (failed (maybeSliceOperand))
1341+ return failure ();
1342+ OpOperand *sliceOperand = *maybeSliceOperand;
1343+ unsigned OperandIndex = sliceOperand->getOperandNumber ();
1344+
1345+ if (!controlFn (sliceOperand))
1346+ return failure ();
1347+
1348+ tensor::ExtractSliceOp producerSliceOp =
1349+ sliceOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1350+ assert (producerSliceOp && " expect a valid ExtractSliceOp" );
1351+
1352+ if (producerSliceOp.getSource ().getType ().getRank () !=
1353+ producerSliceOp.getResult ().getType ().getRank ()) {
1354+ return rewriter.notifyMatchFailure (
1355+ genericOp,
1356+ " propagation of rank-reducing extract slice is unsupported." );
1357+ }
1358+
1359+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides ();
1360+ if (!areAllConstantIntValue (strides, 1 ))
1361+ return rewriter.notifyMatchFailure (
1362+ genericOp, " propagation of strided extract slice is unsupported." );
1363+
1364+ // check if we can support the propagation of this extractSlice
1365+ // through the generic op and if so return the dimensions that
1366+
1367+ auto maybePartialSliceDimMap =
1368+ getPartialSliceDimInfo (genericOp, sliceOperand);
1369+
1370+ if (failed (maybePartialSliceDimMap)) {
1371+ return failure ();
1372+ }
1373+
1374+ auto partialSliceDimMap = *maybePartialSliceDimMap;
1375+
1376+ SmallVector<utils::IteratorType> iterators =
1377+ genericOp.getIteratorTypesArray ();
1378+ bool hasPartialReductionDimSlice =
1379+ llvm::any_of (partialSliceDimMap, [&](const auto &slice) {
1380+ int64_t sliceDim = slice.first ;
1381+ return iterators[sliceDim] == utils::IteratorType::reduction;
1382+ });
1383+
1384+ // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
1385+ Location loc = genericOp->getLoc ();
1386+ AffineExpr dim0, dim1;
1387+ bindDims (rewriter.getContext (), dim0, dim1);
1388+ auto subMap = AffineMap::get (2 , 0 , {dim0 - dim1});
1389+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1390+ return affine::makeComposedFoldedAffineApply (rewriter, loc, subMap,
1391+ {v1, v2});
1392+ };
1393+
1394+ MLIRContext *ctx = genericOp.getContext ();
1395+ SmallVector<Value> paddedInputs;
1396+ for (auto [idx, operand] : llvm::enumerate (genericOp.getDpsInputOperands ())) {
1397+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1398+ paddedInputs.push_back (producerSliceOp.getSource ());
1399+ continue ;
1400+ }
1401+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (operand);
1402+ SmallVector<OpFoldResult> operandLowPads (IndexingMap.getNumResults (),
1403+ getAsIndexOpFoldResult (ctx, 0 ));
1404+ SmallVector<OpFoldResult> operandHighPads (IndexingMap.getNumResults (),
1405+ getAsIndexOpFoldResult (ctx, 0 ));
1406+ for (auto [idx, expr] : llvm::enumerate (IndexingMap.getResults ())) {
1407+ if (!isa<AffineDimExpr>(expr)) {
1408+ continue ;
1409+ }
1410+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1411+ if (!partialSliceDimMap.contains (dimExpr.getPosition ())) {
1412+ continue ;
1413+ }
1414+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition ()];
1415+ operandLowPads[idx] = sliceDimInfo.offset ;
1416+ operandHighPads[idx] =
1417+ sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1418+ sliceDimInfo.sliceSize );
1419+ }
1420+ auto paddingValue = ub::PoisonOp::create (
1421+ rewriter, loc, getElementTypeOrSelf (operand->get ().getType ()));
1422+ auto paddedOperand = tensor::PadOp::create (
1423+ rewriter, loc, Type (), operand->get (), operandLowPads, operandHighPads,
1424+ paddingValue, /* nofold=*/ false );
1425+ paddedInputs.push_back (paddedOperand);
1426+ }
1427+ AffineMap outputIndexingMap =
1428+ genericOp.getMatchingIndexingMap (genericOp.getDpsInitOperand (0 ));
1429+
1430+ auto outputShapeType =
1431+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand (0 )->get ().getType ());
1432+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector (
1433+ outputShapeType.getShape (),
1434+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr (sz); });
1435+ SmallVector<OpFoldResult> newSizes = OutputShape;
1436+ SmallVector<OpFoldResult> outputLowPads (outputIndexingMap.getNumResults (),
1437+ getAsIndexOpFoldResult (ctx, 0 ));
1438+ SmallVector<OpFoldResult> outputHighPads (outputIndexingMap.getNumResults (),
1439+ getAsIndexOpFoldResult (ctx, 0 ));
1440+ SmallVector<OpFoldResult> newStrides (outputIndexingMap.getNumResults (),
1441+ getAsIndexOpFoldResult (ctx, 1 ));
1442+ for (auto [idx, expr] : llvm::enumerate (outputIndexingMap.getResults ())) {
1443+ if (!isa<AffineDimExpr>(expr)) {
1444+ continue ;
1445+ }
1446+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1447+ if (!partialSliceDimMap.contains (dimExpr.getPosition ())) {
1448+ continue ;
1449+ }
1450+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition ()];
1451+ outputLowPads[idx] = sliceDimInfo.offset ;
1452+ outputHighPads[idx] = sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1453+ sliceDimInfo.sliceSize );
1454+ OutputShape[idx] = sliceDimInfo.outputSize ;
1455+ newSizes[idx] = sliceDimInfo.sliceSize ;
1456+ }
1457+ Value newPadOutput;
1458+ auto outputElType =
1459+ getElementTypeOrSelf (genericOp.getDpsInits ()[0 ].getType ());
1460+ if (isGenericOutsNotUsed (genericOp)) {
1461+ newPadOutput =
1462+ tensor::EmptyOp::create (rewriter, loc, OutputShape, outputElType);
1463+ } else {
1464+ auto paddingValue = ub::PoisonOp::create (rewriter, loc, outputElType);
1465+ newPadOutput = tensor::PadOp::create (
1466+ rewriter, loc, Type (), genericOp.getDpsInits ()[0 ], outputLowPads,
1467+ outputHighPads, paddingValue, /* nofold=*/ false );
1468+ }
1469+
1470+ auto newGenericOp = linalg::GenericOp::create (
1471+ rewriter, loc, newPadOutput.getType (), paddedInputs, {newPadOutput},
1472+ genericOp.getIndexingMapsArray (), genericOp.getIteratorTypesArray (),
1473+ /* bodyBuild=*/ nullptr , linalg::getPrunedAttributeList (genericOp));
1474+ rewriter.cloneRegionBefore (genericOp.getRegion (), newGenericOp.getRegion (),
1475+ newGenericOp.getRegion ().begin ());
1476+
1477+ auto extractOp = tensor::ExtractSliceOp::create (
1478+ rewriter, loc,
1479+ newGenericOp.getTiedOpResult (newGenericOp.getDpsInitOperand (0 )),
1480+ outputLowPads, newSizes, newStrides);
1481+ Value extractRes = extractOp.getResult ();
1482+
1483+ return std::make_tuple (newGenericOp, extractRes);
1484+ }
1485+
1486+ class PushDownExtractSliceOpThroughGenericOp final
1487+ : public OpRewritePattern<GenericOp> {
1488+ public:
1489+ PushDownExtractSliceOpThroughGenericOp (MLIRContext *context,
1490+ ControlPropagationFn fun)
1491+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1492+
1493+ LogicalResult matchAndRewrite (GenericOp genericOp,
1494+ PatternRewriter &rewriter) const override {
1495+ auto genericAndRepl =
1496+ pushDownExtractSliceOpThroughGenericOp (rewriter, genericOp, controlFn);
1497+ if (failed (genericAndRepl))
1498+ return failure ();
1499+ rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
1500+ return success ();
1501+ }
1502+
1503+ private:
1504+ ControlPropagationFn controlFn;
1505+ };
1506+
12391507} // namespace
12401508
12411509void mlir::linalg::populateDataLayoutPropagationPatterns (
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
12471515 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
12481516 patterns.getContext (), controlPackUnPackPropagation);
12491517}
1518+
1519+ void mlir::linalg::populateExtractSliceSinkingPatterns (
1520+ RewritePatternSet &patterns,
1521+ const ControlPropagationFn &controlPackUnPackPropagation) {
1522+ patterns.insert <PushDownExtractSliceOpThroughGenericOp>(
1523+ patterns.getContext (), controlPackUnPackPropagation);
1524+ }
0 commit comments