@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
20472047 return {};
20482048}
20492049
2050+ static Attribute foldDenseElementsAttrSrcExtractOp (ExtractOp extractOp,
2051+ Attribute srcAttr) {
2052+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2053+ if (!denseAttr) {
2054+ return {};
2055+ }
2056+
2057+ if (denseAttr.isSplat ()) {
2058+ Attribute newAttr = denseAttr.getSplatValue <Attribute>();
2059+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType ()))
2060+ newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2061+ return newAttr;
2062+ }
2063+
2064+ auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType ());
2065+ if (vecTy.isScalable ())
2066+ return {};
2067+
2068+ if (extractOp.hasDynamicPosition ()) {
2069+ return {};
2070+ }
2071+
2072+ // Calculate the linearized position of the continuous chunk of elements to
2073+ // extract.
2074+ llvm::SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2075+ copy (extractOp.getStaticPosition (), completePositions.begin ());
2076+ int64_t elemBeginPosition =
2077+ linearize (completePositions, computeStrides (vecTy.getShape ()));
2078+ auto denseValuesBegin =
2079+ denseAttr.value_begin <TypedAttr>() + elemBeginPosition;
2080+
2081+ TypedAttr newAttr;
2082+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType ())) {
2083+ SmallVector<Attribute> elementValues (
2084+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2085+ newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2086+ } else {
2087+ newAttr = *denseValuesBegin;
2088+ }
2089+
2090+ return newAttr;
2091+ }
2092+
20502093OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
20512094 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20522095 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20582101 return res;
20592102 if (auto res = foldPoisonSrcExtractOp (adaptor.getVector ()))
20602103 return res;
2104+ if (auto res = foldDenseElementsAttrSrcExtractOp (*this , adaptor.getVector ()))
2105+ return res;
20612106 if (succeeded (foldExtractOpFromExtractChain (*this )))
20622107 return getResult ();
20632108 if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21212166 }
21222167};
21232168
2124- // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
2125- class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2126- public:
2127- using OpRewritePattern::OpRewritePattern;
2128-
2129- LogicalResult matchAndRewrite (ExtractOp extractOp,
2130- PatternRewriter &rewriter) const override {
2131- // Return if 'ExtractOp' operand is not defined by a splat vector
2132- // ConstantOp.
2133- Value sourceVector = extractOp.getVector ();
2134- Attribute vectorCst;
2135- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2136- return failure ();
2137- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2138- if (!splat)
2139- return failure ();
2140- TypedAttr newAttr = splat.getSplatValue <TypedAttr>();
2141- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType ()))
2142- newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2143- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2144- return success ();
2145- }
2146- };
2147-
2148- // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2149- class ExtractOpNonSplatConstantFolder final
2150- : public OpRewritePattern<ExtractOp> {
2151- public:
2152- using OpRewritePattern::OpRewritePattern;
2153-
2154- LogicalResult matchAndRewrite (ExtractOp extractOp,
2155- PatternRewriter &rewriter) const override {
2156- // TODO: Canonicalization for dynamic position not implemented yet.
2157- if (extractOp.hasDynamicPosition ())
2158- return failure ();
2159-
2160- // Return if 'ExtractOp' operand is not defined by a compatible vector
2161- // ConstantOp.
2162- Value sourceVector = extractOp.getVector ();
2163- Attribute vectorCst;
2164- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2165- return failure ();
2166-
2167- auto vecTy = llvm::cast<VectorType>(sourceVector.getType ());
2168- if (vecTy.isScalable ())
2169- return failure ();
2170-
2171- // The splat case is handled by `ExtractOpSplatConstantFolder`.
2172- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2173- if (!dense || dense.isSplat ())
2174- return failure ();
2175-
2176- // Calculate the linearized position of the continuous chunk of elements to
2177- // extract.
2178- llvm::SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2179- copy (extractOp.getStaticPosition (), completePositions.begin ());
2180- int64_t elemBeginPosition =
2181- linearize (completePositions, computeStrides (vecTy.getShape ()));
2182- auto denseValuesBegin = dense.value_begin <TypedAttr>() + elemBeginPosition;
2183-
2184- TypedAttr newAttr;
2185- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType ())) {
2186- SmallVector<Attribute> elementValues (
2187- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2188- newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2189- } else {
2190- newAttr = *denseValuesBegin;
2191- }
2192-
2193- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2194- return success ();
2195- }
2196- };
2197-
21982169// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
21992170class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
22002171public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23322303
23332304void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
23342305 MLIRContext *context) {
2335- results.add <ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2336- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2306+ results.add <ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
23372307 results.add (foldExtractFromShapeCastToShapeCast);
23382308 results.add (foldExtractFromFromElements);
23392309}
0 commit comments