Skip to content

Commit 3007c1d

Browse files
committed
Remove subview
1 parent 70adf3a commit 3007c1d

File tree

2 files changed

+1
-71
lines changed

2 files changed

+1
-71
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -182,66 +182,6 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
182182
}
183183
};
184184

185-
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
186-
using OpRewritePattern::OpRewritePattern;
187-
188-
LogicalResult matchAndRewrite(memref::SubViewOp op,
189-
PatternRewriter &rewriter) const override {
190-
Value memref = op.getSource();
191-
if (!needFlattening(memref))
192-
return rewriter.notifyMatchFailure(op, "already flattened");
193-
194-
if (!checkLayout(memref))
195-
return rewriter.notifyMatchFailure(op, "unsupported layout");
196-
197-
Location loc = op.getLoc();
198-
SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
199-
SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
200-
SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
201-
202-
// base, finalOffset, strides
203-
memref::ExtractStridedMetadataOp stridedMetadata =
204-
rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
205-
206-
auto sourceType = cast<MemRefType>(memref.getType());
207-
auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
208-
OpFoldResult linearizedIndices;
209-
memref::LinearizedMemRefInfo linearizedInfo;
210-
std::tie(linearizedInfo, linearizedIndices) =
211-
memref::getLinearizedMemRefOffsetAndSize(
212-
rewriter, loc, typeBit, typeBit,
213-
stridedMetadata.getConstifiedMixedOffset(),
214-
stridedMetadata.getConstifiedMixedSizes(),
215-
stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
216-
auto finalOffset = linearizedInfo.linearizedOffset;
217-
auto strides = stridedMetadata.getConstifiedMixedStrides();
218-
219-
auto srcType = cast<MemRefType>(memref.getType());
220-
auto resultType = cast<MemRefType>(op.getType());
221-
unsigned subRank = static_cast<unsigned>(resultType.getRank());
222-
223-
llvm::SmallBitVector droppedDims = op.getDroppedDims();
224-
225-
SmallVector<OpFoldResult> finalSizes;
226-
finalSizes.reserve(subRank);
227-
228-
SmallVector<OpFoldResult> finalStrides;
229-
finalStrides.reserve(subRank);
230-
231-
for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
232-
if (droppedDims.test(i))
233-
continue;
234-
235-
finalSizes.push_back(subSizes[i]);
236-
finalStrides.push_back(strides[i]);
237-
}
238-
239-
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
240-
op, resultType, memref, finalOffset, finalSizes, finalStrides);
241-
return success();
242-
}
243-
};
244-
245185
struct FlattenMemrefsPass
246186
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
247187
using Base::Base;
@@ -271,6 +211,6 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
271211
MemRefRewritePattern<vector::TransferReadOp>,
272212
MemRefRewritePattern<vector::TransferWriteOp>,
273213
MemRefRewritePattern<vector::MaskedLoadOp>,
274-
MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
214+
MemRefRewritePattern<vector::MaskedStoreOp>>(
275215
patterns.getContext());
276216
}

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,6 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
4646

4747
// -----
4848

49-
func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> {
50-
%subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
51-
return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
52-
}
53-
// CHECK-LABEL: func @load_scalar_from_memref_subview
54-
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
55-
// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1]
56-
57-
// -----
58-
5949
func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) {
6050
memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
6151
return

0 commit comments

Comments
 (0)