Skip to content

Commit 05cdabb

Browse files
committed
subview
1 parent d5cda23 commit 05cdabb

File tree

2 files changed

+152
-42
lines changed

2 files changed

+152
-42
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 139 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/OpDefinition.h"
2828
#include "mlir/IR/PatternMatch.h"
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30+
#include "llvm/ADT/STLExtras.h"
3031
#include "llvm/ADT/TypeSwitch.h"
3132

3233
namespace mlir {
@@ -47,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
4748
return cast<Value>(in);
4849
}
4950

51+
5052
/// Returns a collapsed memref and the linearized index to access the element
5153
/// at the specified indices.
5254
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -90,12 +92,15 @@ static bool needFlattening(Value val) {
9092
return type.getRank() > 1;
9193
}
9294

93-
static bool checkLayout(Value val) {
94-
auto type = cast<MemRefType>(val.getType());
95+
/// Returns true when `type` has a layout the flattening patterns can handle:
/// either the identity layout or an explicit strided layout.
static bool checkLayout(MemRefType type) {
  auto layout = type.getLayout();
  return layout.isIdentity() || isa<StridedLayoutAttr>(layout);
}
9899

100+
/// Convenience overload: checks the layout of a memref-typed SSA value.
static bool checkLayout(Value val) {
  auto memrefType = cast<MemRefType>(val.getType());
  return checkLayout(memrefType);
}
103+
99104
namespace {
100105
static Value getTargetMemref(Operation *op) {
101106
return llvm::TypeSwitch<Operation *, Value>(op)
@@ -368,38 +373,131 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
368373
};
369374

370375

371-
/*
372-
// Flattens memref subspan ops with more than 1 dimensions to 1 dimension.
373-
struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
374-
using OpConversionPattern::OpConversionPattern;
376+
// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
377+
struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
378+
using OpRewritePattern::OpRewritePattern;
375379

376-
LogicalResult
377-
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
378-
ConversionPatternRewriter &rewriter) const override {
379-
if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
380-
return rewriter.notifyMatchFailure(
381-
op, "expected converted memref of rank <= 1");
382-
}
383-
Type neededResultType =
384-
getTypeConverter()->convertType(op.getResult().getType());
385-
if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
380+
LogicalResult matchAndRewrite(memref::SubViewOp op,
381+
PatternRewriter &rewriter) const override {
382+
auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
383+
if (!sourceType || sourceType.getRank() <= 1)
384+
return failure();
385+
if (!checkLayout(sourceType))
386386
return failure();
387-
Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
388-
op.getLoc(), rewriter);
389-
SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
390-
rewriter, op.getLoc(), op.getMixedOffsets());
391-
Value linearOffset =
392-
linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
393-
Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
394-
Value newSubView = memref::SubViewOp::create(
395-
rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
396-
ValueRange({size}), ValueRange({stride}));
397-
rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
398-
newSubView);
387+
388+
MemRefType resultType = op.getType();
389+
if (resultType.getRank() <= 1 || !checkLayout(resultType))
390+
return failure();
391+
392+
unsigned elementBitWidth = sourceType.getElementTypeBitWidth();
393+
if (!elementBitWidth)
394+
return failure();
395+
396+
Location loc = op.getLoc();
397+
398+
// Materialize offsets as values so they can participate in linearization.
399+
SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
400+
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
401+
SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
402+
403+
SmallVector<Value> offsetValues;
404+
offsetValues.reserve(mixedOffsets.size());
405+
for (OpFoldResult ofr : mixedOffsets)
406+
offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
407+
408+
auto [flatSource, linearOffset] =
409+
getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
410+
ValueRange(offsetValues));
411+
412+
memref::ExtractStridedMetadataOp sourceMetadata =
413+
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
414+
415+
SmallVector<OpFoldResult> sourceStrides =
416+
sourceMetadata.getConstifiedMixedStrides();
417+
OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
418+
419+
llvm::SmallBitVector droppedDims = op.getDroppedDims();
420+
421+
SmallVector<OpFoldResult> resultSizes;
422+
SmallVector<OpFoldResult> resultStrides;
423+
resultSizes.reserve(resultType.getRank());
424+
resultStrides.reserve(resultType.getRank());
425+
426+
OpFoldResult resultOffset = sourceOffset;
427+
for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
428+
mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
429+
auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
430+
OpFoldResult contribution = [&]() -> OpFoldResult {
431+
if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
432+
if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
433+
auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
434+
auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
435+
return rewriter.getIndexAttr(offsetInt * strideInt);
436+
}
437+
}
438+
Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
439+
Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
440+
return rewriter.create<arith::MulIOp>(loc, offsetVal, strideVal)
441+
.getResult();
442+
}();
443+
resultOffset = [&]() -> OpFoldResult {
444+
if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
445+
if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
446+
auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
447+
auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
448+
return rewriter.getIndexAttr(offsetInt + contribInt);
449+
}
450+
}
451+
Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
452+
Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
453+
return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
454+
.getResult();
455+
}();
456+
457+
if (droppedDims.test(idx))
458+
continue;
459+
460+
resultSizes.push_back(sizeOfr);
461+
OpFoldResult combinedStride = [&]() -> OpFoldResult {
462+
if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
463+
if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
464+
auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
465+
auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
466+
return rewriter.getIndexAttr(relStrideInt * strideInt);
467+
}
468+
}
469+
Value relStrideVal =
470+
getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
471+
Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
472+
return rewriter.create<arith::MulIOp>(loc, relStrideVal, strideVal)
473+
.getResult();
474+
}();
475+
resultStrides.push_back(combinedStride);
476+
}
477+
478+
memref::LinearizedMemRefInfo linearizedInfo;
479+
[[maybe_unused]] OpFoldResult linearizedIndex;
480+
std::tie(linearizedInfo, linearizedIndex) =
481+
memref::getLinearizedMemRefOffsetAndSize(
482+
rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
483+
resultSizes, resultStrides);
484+
485+
Value flattenedSize = getValueFromOpFoldResult(
486+
rewriter, loc, linearizedInfo.linearizedSize);
487+
Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
488+
489+
Value flattenedSubview = memref::SubViewOp::create(
490+
rewriter, loc, flatSource, ValueRange{linearOffset},
491+
ValueRange{flattenedSize}, ValueRange{strideOne});
492+
493+
Value replacement = memref::ReinterpretCastOp::create(
494+
rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
495+
resultStrides);
496+
497+
rewriter.replaceOp(op, replacement);
399498
return success();
400499
}
401500
};
402-
*/
403501

404502
struct FlattenMemrefsPass
405503
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -422,18 +520,6 @@ struct FlattenMemrefsPass
422520

423521
} // namespace
424522

425-
void memref::populateFlattenVectorOpsOnMemrefPatterns(
426-
RewritePatternSet &patterns) {
427-
patterns.insert<MemRefRewritePattern<vector::LoadOp>,
428-
MemRefRewritePattern<vector::StoreOp>,
429-
MemRefRewritePattern<vector::TransferReadOp>,
430-
MemRefRewritePattern<vector::TransferWriteOp>,
431-
MemRefRewritePattern<vector::MaskedLoadOp>,
432-
MemRefRewritePattern<vector::MaskedStoreOp>>(
433-
patterns.getContext());
434-
}
435-
436-
/// Special pattern for GetGlobalOp to avoid infinite loops
437523
struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
438524
using OpRewritePattern::OpRewritePattern;
439525

@@ -470,6 +556,17 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
470556
}
471557
};
472558

559+
/// Populates `patterns` with rewrites that flatten the memref operands of
/// vector load/store, transfer read/write, and masked load/store ops so the
/// ops access a rank-1 view of the underlying buffer.
void memref::populateFlattenVectorOpsOnMemrefPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>>(
      patterns.getContext());
}
569+
473570
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
474571
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
475572
MemRefRewritePattern<memref::StoreOp>,
@@ -478,7 +575,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
478575
MemRefRewritePattern<memref::DeallocOp>,
479576
FlattenExpandShape,
480577
FlattenCollapseShape,
481-
//FlattenSubView,
578+
FlattenSubView,
482579
FlattenGetGlobal,
483580
FlattenGlobal>(
484581
patterns.getContext());

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
194194

195195
// -----
196196

197+
// A static rank-2 subview is rewritten into a 1-D subview of the flattened
// source (offset 1, 8 elements) followed by a reinterpret_cast back to the
// 2-D strided result type.
func.func @flatten_subview_static(%arg0: memref<3x4xf32, strided<[4, 1], offset: 0>>) -> memref<2x2xf32, strided<[4, 1], offset: 1>> {
  %sub = memref.subview %arg0[0, 1] [2, 2] [1, 1]
    : memref<3x4xf32, strided<[4, 1], offset: 0>> to memref<2x2xf32, strided<[4, 1], offset: 1>>
  return %sub : memref<2x2xf32, strided<[4, 1], offset: 1>>
}
// CHECK-LABEL: func @flatten_subview_static
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12], strides: [1]
// CHECK: %[[SUB:.*]] = memref.subview %[[FLAT]][%[[C1]]] [%[[C8]]] [%[[C1]]]
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [1], sizes: [2, 2], strides: [4, 1]
// CHECK: return %[[CAST]]
209+
197210
func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
198211
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
199212
: memref<2x3x4xf32> into memref<6x4xf32>

0 commit comments

Comments
 (0)