@@ -182,66 +182,6 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
182182 }
183183};
184184
185- struct FlattenSubview : public OpRewritePattern <memref::SubViewOp> {
186- using OpRewritePattern::OpRewritePattern;
187-
188- LogicalResult matchAndRewrite (memref::SubViewOp op,
189- PatternRewriter &rewriter) const override {
190- Value memref = op.getSource ();
191- if (!needFlattening (memref))
192- return rewriter.notifyMatchFailure (op, " already flattened" );
193-
194- if (!checkLayout (memref))
195- return rewriter.notifyMatchFailure (op, " unsupported layout" );
196-
197- Location loc = op.getLoc ();
198- SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets ();
199- SmallVector<OpFoldResult> subSizes = op.getMixedSizes ();
200- SmallVector<OpFoldResult> subStrides = op.getMixedStrides ();
201-
202- // base, finalOffset, strides
203- memref::ExtractStridedMetadataOp stridedMetadata =
204- rewriter.create <memref::ExtractStridedMetadataOp>(loc, memref);
205-
206- auto sourceType = cast<MemRefType>(memref.getType ());
207- auto typeBit = sourceType.getElementType ().getIntOrFloatBitWidth ();
208- OpFoldResult linearizedIndices;
209- memref::LinearizedMemRefInfo linearizedInfo;
210- std::tie (linearizedInfo, linearizedIndices) =
211- memref::getLinearizedMemRefOffsetAndSize (
212- rewriter, loc, typeBit, typeBit,
213- stridedMetadata.getConstifiedMixedOffset (),
214- stridedMetadata.getConstifiedMixedSizes (),
215- stridedMetadata.getConstifiedMixedStrides (), op.getMixedOffsets ());
216- auto finalOffset = linearizedInfo.linearizedOffset ;
217- auto strides = stridedMetadata.getConstifiedMixedStrides ();
218-
219- auto srcType = cast<MemRefType>(memref.getType ());
220- auto resultType = cast<MemRefType>(op.getType ());
221- unsigned subRank = static_cast <unsigned >(resultType.getRank ());
222-
223- llvm::SmallBitVector droppedDims = op.getDroppedDims ();
224-
225- SmallVector<OpFoldResult> finalSizes;
226- finalSizes.reserve (subRank);
227-
228- SmallVector<OpFoldResult> finalStrides;
229- finalStrides.reserve (subRank);
230-
231- for (auto i : llvm::seq (0u , static_cast <unsigned >(srcType.getRank ()))) {
232- if (droppedDims.test (i))
233- continue ;
234-
235- finalSizes.push_back (subSizes[i]);
236- finalStrides.push_back (strides[i]);
237- }
238-
239- rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
240- op, resultType, memref, finalOffset, finalSizes, finalStrides);
241- return success ();
242- }
243- };
244-
245185struct FlattenMemrefsPass
246186 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
247187 using Base::Base;
@@ -271,6 +211,6 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
271211 MemRefRewritePattern<vector::TransferReadOp>,
272212 MemRefRewritePattern<vector::TransferWriteOp>,
273213 MemRefRewritePattern<vector::MaskedLoadOp>,
274- MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview >(
214+ MemRefRewritePattern<vector::MaskedStoreOp>>(
275215 patterns.getContext ());
276216}
0 commit comments