 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
+#include <vector>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
+///   tensor<12xf64>
+/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
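+///
+/// (The constant indices above follow from delinearizing 10 over the source
+/// basis (3, 4): 10 floordiv 4 = 2 and 10 mod 4 = 2. For a dynamic index the
+/// rewrite emits an affine.delinearize_index op instead.)
+struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {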
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto collapseOp =
+        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp)
+      return failure();
+    if (!collapseOp.getSrcType().hasStaticShape())
+      return failure();
+
+    auto sourceSizes = collapseOp.getSrcType().getShape();
+
+    SmallVector<Value> indices(extractOp.getIndices().begin(),
+                               extractOp.getIndices().end());
+    SmallVector<Value> sourceIndices;
+    for (auto [index, group] :
+         llvm::zip(indices, collapseOp.getReassociationIndices())) {
+      assert(!group.empty() && "association indices groups cannot be empty");
+      auto groupSize = group.size();
+
+      if (groupSize == 1) {
+        sourceIndices.push_back(index);
+        continue;
+      }
+
+      SmallVector<int64_t> basis =
+          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
+      llvm::append_range(sourceIndices, delinearize.getResults());
+    }
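+    // An empty reassociation means a source whose dims are all 1 was
+    // collapsed into a rank-0 tensor, so every source index is a constant
+    // zero.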
+    if (collapseOp.getReassociationIndices().empty()) {
+      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+      int64_t srcRank =
+          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), zeroAffineMap,
+          ArrayRef<OpFoldResult>{});
+      for (int64_t i = 0; i < srcRank; i++) {
+        sourceIndices.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        extractOp, collapseOp.getSrc(), sourceIndices);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
   return success();
 }
 
+/// If we have an ExtractOp consuming an InsertOp with the same
+/// indices, we can return the InsertOp's scalar directly.
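+///
+/// For example (a sketch; %s, %t, %i, %j are placeholders):
+///   %0 = tensor.insert %s into %t[%i, %j] : tensor<4x4xf32>
+///   %e = tensor.extract %0[%i, %j] : tensor<4x4xf32>
+/// folds %e to %s.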
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsert(ExtractOp extractOp) {
+  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
+
+  auto isSame = [](Value a, Value b) {
+    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
+  };
+  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
+      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
+    return insertOp.getScalar();
+
+  return {};
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   if (Attribute tensor = adaptor.getTensor()) {
     // If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
       return elementsAttr.getValues<Attribute>()[indices];
   }
 
+  if (Value result = foldExtractAfterInsert(*this))
+    return result;
+
   return {};
 }
 
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractFromTensorCast>(context);
 }
 
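+/// A minimal usage sketch (hypothetical caller; assumes the standard greedy
+/// rewrite driver is in scope):
+///   RewritePatternSet patterns(op->getContext());
+///   tensor::populateFoldCollapseExtractPatterns(patterns);
+///   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+///     signalPassFailure();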
+void mlir::tensor::populateFoldCollapseExtractPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Pattern to fold an insert op with a constant destination and a constant
+/// scalar into a new constant.
+///
+/// Example:
+/// ```
+///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+///   %c0 = arith.constant 0 : index
+///   %c4_f32 = arith.constant 4.0 : f32
+///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Requires a ranked tensor type.
+    auto destType =
+        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+    if (!destType)
+      return failure();
+
+    // The pattern requires constant indices.
+    SmallVector<uint64_t, 8> indices;
+    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+      auto indiceAttr = dyn_cast<Attribute>(indice);
+      if (!indiceAttr)
+        return failure();
+      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+    }
+
+    // Requires a constant scalar to insert.
+    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+    if (!scalarAttr)
+      return failure();
+
+    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+            insertOp.getDest().getDefiningOp())) {
+      if (auto sourceAttr =
+              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+        // Update the attribute at the inserted index.
+        auto sourceValues = sourceAttr.getValues<Attribute>();
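+        // getFlattenedIndex linearizes the multi-dimensional index in the
+        // attribute's row-major element order, e.g. index [1, 2] into a
+        // 3x4 tensor maps to 1 * 4 + 2 = 6.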
+        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+        std::vector<Attribute> updatedValues;
+        updatedValues.reserve(sourceAttr.getNumElements());
+        for (int64_t i = 0; i < sourceAttr.getNumElements(); ++i) {
+          updatedValues.push_back(i == flattenedIndex ? scalarAttr
+                                                      : sourceValues[i]);
+        }
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            insertOp, sourceAttr.getType(),
+            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
+                                   updatedValues));
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+} // namespace
+
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertOpConstantFold>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//