@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
16241624// InsertOp
16251625// ===----------------------------------------------------------------------===//
16261626
1627- namespace {
1628-
1629- // / Pattern to fold an insert op of a constant destination and scalar to a new
1630- // / constant.
1631- // /
1632- // / Example:
1633- // / ```
1634- // / %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1635- // / %c0 = arith.constant 0 : index
1636- // / %c4_f32 = arith.constant 4.0 : f32
1637- // / %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
1638- // / ```
1639- // / is rewritten into:
1640- // / ```
1641- // / %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1642- // / ```
1643- class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
1644- public:
1645- using OpRewritePattern<InsertOp>::OpRewritePattern;
1646-
1647- LogicalResult matchAndRewrite (InsertOp insertOp,
1648- PatternRewriter &rewriter) const override {
1649- // Requires a ranked tensor type.
1650- auto destType =
1651- llvm::dyn_cast<RankedTensorType>(insertOp.getDest ().getType ());
1652- if (!destType)
1653- return failure ();
1654-
1655- // Pattern requires constant indices
1656- SmallVector<uint64_t , 8 > indices;
1657- for (OpFoldResult indice : getAsOpFoldResult (insertOp.getIndices ())) {
1658- auto indiceAttr = dyn_cast<Attribute>(indice);
1659- if (!indiceAttr)
1660- return failure ();
1661- indices.push_back (llvm::cast<IntegerAttr>(indiceAttr).getInt ());
1662- }
1663-
1664- // Requires a constant scalar to insert
1665- OpFoldResult scalar = getAsOpFoldResult (insertOp.getScalar ());
1666- Attribute scalarAttr = dyn_cast<Attribute>(scalar);
1667- if (!scalarAttr)
1668- return failure ();
1669-
1670- if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
1671- insertOp.getDest ().getDefiningOp ())) {
1672- if (auto sourceAttr =
1673- llvm::dyn_cast<ElementsAttr>(constantOp.getValue ())) {
1674- // Update the attribute at the inserted index.
1675- auto sourceValues = sourceAttr.getValues <Attribute>();
1676- auto flattenedIndex = sourceAttr.getFlattenedIndex (indices);
1677- std::vector<Attribute> updatedValues;
1678- updatedValues.reserve (sourceAttr.getNumElements ());
1679- for (unsigned i = 0 ; i < sourceAttr.getNumElements (); ++i) {
1680- updatedValues.push_back (i == flattenedIndex ? scalarAttr
1681- : sourceValues[i]);
1682- }
1683- rewriter.replaceOpWithNewOp <arith::ConstantOp>(
1684- insertOp, sourceAttr.getType (),
1685- DenseElementsAttr::get (cast<ShapedType>(sourceAttr.getType ()),
1686- updatedValues));
1687- return success ();
1688- }
1689- }
1690-
1691- return failure ();
1692- }
1693- };
1694-
1695- } // namespace
1696-
16971627void InsertOp::getAsmResultNames (
16981628 function_ref<void (Value, StringRef)> setNameFn) {
16991629 setNameFn (getResult (), " inserted" );
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
17171647 return {};
17181648}
17191649
1720- void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
1721- MLIRContext *context) {
1722- results.add <InsertOpConstantFold>(context);
1723- }
1724-
17251650// ===----------------------------------------------------------------------===//
17261651// GenerateOp
17271652// ===----------------------------------------------------------------------===//
0 commit comments