@@ -2031,20 +2031,71 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20312031static Attribute foldPoisonIndexInsertExtractOp (MLIRContext *context,
20322032 ArrayRef<int64_t > staticPos,
20332033 int64_t poisonVal) {
2034- if (!llvm:: is_contained (staticPos, poisonVal))
2034+ if (!is_contained (staticPos, poisonVal))
20352035 return {};
20362036
20372037 return ub::PoisonAttr::get (context);
20382038}
20392039
20402040// / Fold a vector extract from is a poison source.
20412041static Attribute foldPoisonSrcExtractOp (Attribute srcAttr) {
2042- if (llvm:: isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2042+ if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
20432043 return srcAttr;
20442044
20452045 return {};
20462046}
20472047
2048+ // / Fold a vector extract extracting from a DenseElementsAttr.
2049+ static Attribute foldDenseElementsAttrSrcExtractOp (ExtractOp extractOp,
2050+ Attribute srcAttr) {
2051+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2052+ if (!denseAttr) {
2053+ return {};
2054+ }
2055+
2056+ if (denseAttr.isSplat ()) {
2057+ Attribute newAttr = denseAttr.getSplatValue <Attribute>();
2058+ if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType ()))
2059+ newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2060+ return newAttr;
2061+ }
2062+
2063+ auto vecTy = cast<VectorType>(extractOp.getSourceVectorType ());
2064+ if (vecTy.isScalable ())
2065+ return {};
2066+
2067+ if (extractOp.hasDynamicPosition ()) {
2068+ return {};
2069+ }
2070+
2071+ // Materializing subsets of a large constant array can generally lead to
2072+ // explosion in IR size because of different combination of subsets that
2073+ // can exist. However, vector.extract is a restricted form of subset
2074+ // extract where you can only extract non-overlapping (or the same) subset for
2075+ // a given rank of the subset. Because of this property, the IR size can only
2076+ // increase at most by `rank * size(array)` from a single constant array being
2077+ // extracted by multiple extracts.
2078+
2079+ // Calculate the linearized position of the continuous chunk of elements to
2080+ // extract.
2081+ SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2082+ copy (extractOp.getStaticPosition (), completePositions.begin ());
2083+ int64_t startPos =
2084+ linearize (completePositions, computeStrides (vecTy.getShape ()));
2085+ auto denseValuesBegin = denseAttr.value_begin <TypedAttr>() + startPos;
2086+
2087+ TypedAttr newAttr;
2088+ if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType ())) {
2089+ SmallVector<Attribute> elementValues (
2090+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2091+ newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2092+ } else {
2093+ newAttr = *denseValuesBegin;
2094+ }
2095+
2096+ return newAttr;
2097+ }
2098+
20482099OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
20492100 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20502101 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2056,6 +2107,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20562107 return res;
20572108 if (auto res = foldPoisonSrcExtractOp (adaptor.getVector ()))
20582109 return res;
2110+ if (auto res = foldDenseElementsAttrSrcExtractOp (*this , adaptor.getVector ()))
2111+ return res;
20592112 if (succeeded (foldExtractOpFromExtractChain (*this )))
20602113 return getResult ();
20612114 if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2119,80 +2172,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21192172 }
21202173};
21212174
2122- // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
2123- class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2124- public:
2125- using OpRewritePattern::OpRewritePattern;
2126-
2127- LogicalResult matchAndRewrite (ExtractOp extractOp,
2128- PatternRewriter &rewriter) const override {
2129- // Return if 'ExtractOp' operand is not defined by a splat vector
2130- // ConstantOp.
2131- Value sourceVector = extractOp.getVector ();
2132- Attribute vectorCst;
2133- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2134- return failure ();
2135- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2136- if (!splat)
2137- return failure ();
2138- TypedAttr newAttr = splat.getSplatValue <TypedAttr>();
2139- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType ()))
2140- newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2141- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2142- return success ();
2143- }
2144- };
2145-
2146- // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2147- class ExtractOpNonSplatConstantFolder final
2148- : public OpRewritePattern<ExtractOp> {
2149- public:
2150- using OpRewritePattern::OpRewritePattern;
2151-
2152- LogicalResult matchAndRewrite (ExtractOp extractOp,
2153- PatternRewriter &rewriter) const override {
2154- // TODO: Canonicalization for dynamic position not implemented yet.
2155- if (extractOp.hasDynamicPosition ())
2156- return failure ();
2157-
2158- // Return if 'ExtractOp' operand is not defined by a compatible vector
2159- // ConstantOp.
2160- Value sourceVector = extractOp.getVector ();
2161- Attribute vectorCst;
2162- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2163- return failure ();
2164-
2165- auto vecTy = llvm::cast<VectorType>(sourceVector.getType ());
2166- if (vecTy.isScalable ())
2167- return failure ();
2168-
2169- // The splat case is handled by `ExtractOpSplatConstantFolder`.
2170- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2171- if (!dense || dense.isSplat ())
2172- return failure ();
2173-
2174- // Calculate the linearized position of the continuous chunk of elements to
2175- // extract.
2176- llvm::SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2177- copy (extractOp.getStaticPosition (), completePositions.begin ());
2178- int64_t elemBeginPosition =
2179- linearize (completePositions, computeStrides (vecTy.getShape ()));
2180- auto denseValuesBegin = dense.value_begin <TypedAttr>() + elemBeginPosition;
2181-
2182- TypedAttr newAttr;
2183- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType ())) {
2184- SmallVector<Attribute> elementValues (
2185- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2186- newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2187- } else {
2188- newAttr = *denseValuesBegin;
2189- }
2190-
2191- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2192- return success ();
2193- }
2194- };
2195-
21962175// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
21972176class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
21982177public:
@@ -2330,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23302309
23312310void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
23322311 MLIRContext *context) {
2333- results.add <ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2334- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2312+ results.add <ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
23352313 results.add (foldExtractFromShapeCastToShapeCast);
23362314 results.add (foldExtractFromFromElements);
23372315}
0 commit comments