 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
@@ -47,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
+
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -90,12 +92,15 @@ static bool needFlattening(Value val) {
   return type.getRank() > 1;
 }
 
-static bool checkLayout(Value val) {
-  auto type = cast<MemRefType>(val.getType());
+static bool checkLayout(MemRefType type) {
   return type.getLayout().isIdentity() ||
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
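+/// Convenience overload that checks the layout of a value's memref type.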
+static bool checkLayout(Value val) {
+  return checkLayout(cast<MemRefType>(val.getType()));
+}
+
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
@@ -368,38 +373,131 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
 };
 
 
-/*
-// Flattens memref subspan ops with more than 1 dimensions to 1 dimension.
-struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
-  using OpConversionPattern::OpConversionPattern;
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
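+//
+// The rewrite proceeds in three steps (sketched schematically below):
+// flatten the source memref to 1-D, take a 1-D subview at the linearized
+// offset, then reinterpret_cast the result back to the multi-dimensional
+// result type:
+//
+//   memref.subview %src[offsets][sizes][strides]
+//     ==> %flat = flattened (1-D) view of %src
+//         %sv1d = memref.subview %flat[%linearOffset][%flattenedSize][1]
+//         memref.reinterpret_cast %sv1d to the result offset/sizes/strides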
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
-      return rewriter.notifyMatchFailure(
-          op, "expected converted memref of rank <= 1");
-    }
-    Type neededResultType =
-        getTypeConverter()->convertType(op.getResult().getType());
-    if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+    if (!sourceType || sourceType.getRank() <= 1)
+      return failure();
+    if (!checkLayout(sourceType))
       return failure();
-    Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
-                                              op.getLoc(), rewriter);
-    SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
-        rewriter, op.getLoc(), op.getMixedOffsets());
-    Value linearOffset =
-        linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
-    Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
-    Value newSubView = memref::SubViewOp::create(
-        rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
-        ValueRange({size}), ValueRange({stride}));
-    rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
-                                                newSubView);
+
+    MemRefType resultType = op.getType();
+    if (resultType.getRank() <= 1 || !checkLayout(resultType))
+      return failure();
+
+    unsigned elementBitWidth = sourceType.getElementTypeBitWidth();
+    if (!elementBitWidth)
+      return failure();
+
+    Location loc = op.getLoc();
+
+    // Materialize offsets as values so they can participate in linearization.
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+
+    SmallVector<Value> offsetValues;
+    offsetValues.reserve(mixedOffsets.size());
+    for (OpFoldResult ofr : mixedOffsets)
+      offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
+
+    auto [flatSource, linearOffset] =
+        getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
+                                  ValueRange(offsetValues));
+
+    memref::ExtractStridedMetadataOp sourceMetadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
+
+    SmallVector<OpFoldResult> sourceStrides =
+        sourceMetadata.getConstifiedMixedStrides();
+    OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> resultSizes;
+    SmallVector<OpFoldResult> resultStrides;
+    resultSizes.reserve(resultType.getRank());
+    resultStrides.reserve(resultType.getRank());
+
+    OpFoldResult resultOffset = sourceOffset;
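+    // Accumulate the absolute offset (offset_i * sourceStride_i for every
+    // source dimension), folding statically when both operands are
+    // attributes, and collect sizes/strides only for dimensions that
+    // survive rank reduction.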
+    for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
+             mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
+      auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
+      OpFoldResult contribution = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt * strideInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return rewriter.create<arith::MulIOp>(loc, offsetVal, strideVal)
+            .getResult();
+      }();
+      resultOffset = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
+          if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt + contribInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
+        Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+        return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
+            .getResult();
+      }();
+
+      if (droppedDims.test(idx))
+        continue;
+
+      resultSizes.push_back(sizeOfr);
+      OpFoldResult combinedStride = [&]() -> OpFoldResult {
+        if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(relStrideInt * strideInt);
+          }
+        }
+        Value relStrideVal =
+            getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return rewriter.create<arith::MulIOp>(loc, relStrideVal, strideVal)
+            .getResult();
+      }();
+      resultStrides.push_back(combinedStride);
+    }
+
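+    // Total element count of the strided result: the 1-D subview must cover
+    // every element reachable through the original subview.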
+    memref::LinearizedMemRefInfo linearizedInfo;
+    [[maybe_unused]] OpFoldResult linearizedIndex;
+    std::tie(linearizedInfo, linearizedIndex) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
+            resultSizes, resultStrides);
+
+    Value flattenedSize = getValueFromOpFoldResult(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+    Value flattenedSubview = memref::SubViewOp::create(
+        rewriter, loc, flatSource, ValueRange{linearOffset},
+        ValueRange{flattenedSize}, ValueRange{strideOne});
+
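+    // Re-attach the original rank, sizes, and strides on top of the flat
+    // 1-D subview so users of the result keep seeing the original type.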
+    Value replacement = memref::ReinterpretCastOp::create(
+        rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
+        resultStrides);
+
+    rewriter.replaceOp(op, replacement);
     return success();
   }
 };
-*/
 
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -422,18 +520,6 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenVectorOpsOnMemrefPatterns(
-    RewritePatternSet &patterns) {
-  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
-                  MemRefRewritePattern<vector::StoreOp>,
-                  MemRefRewritePattern<vector::TransferReadOp>,
-                  MemRefRewritePattern<vector::TransferWriteOp>,
-                  MemRefRewritePattern<vector::MaskedLoadOp>,
-                  MemRefRewritePattern<vector::MaskedStoreOp>>(
-      patterns.getContext());
-}
-
-/// Special pattern for GetGlobalOp to avoid infinite loops
 struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -470,6 +556,17 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   }
 };
 
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>>(
+      patterns.getContext());
+}
+
 void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
@@ -478,7 +575,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::DeallocOp>,
                   FlattenExpandShape,
                   FlattenCollapseShape,
-                  // FlattenSubView,
+                  FlattenSubView,
                   FlattenGetGlobal,
                   FlattenGlobal>(
       patterns.getContext());