2121#include " mlir/Dialect/Utils/StaticValueUtils.h"
2222#include " mlir/Dialect/Vector/IR/VectorOps.h"
2323#include " mlir/IR/Attributes.h"
24- #include " mlir/IR/DialectResourceBlobManager.h"
2524#include " mlir/IR/Builders.h"
2625#include " mlir/IR/BuiltinTypes.h"
26+ #include " mlir/IR/DialectResourceBlobManager.h"
2727#include " mlir/IR/OpDefinition.h"
2828#include " mlir/IR/PatternMatch.h"
2929#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
4848 return cast<Value>(in);
4949}
5050
51-
5251// / Returns a collapsed memref and the linearized index to access the element
5352// / at the specified indices.
5453static std::pair<Value, Value> getFlattenMemrefAndOffset (OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
281280 return {};
282281 }
283282
284- LogicalResult
285- matchAndRewrite (memref::GlobalOp globalOp,
286- PatternRewriter &rewriter) const override {
283+ LogicalResult matchAndRewrite (memref::GlobalOp globalOp,
284+ PatternRewriter &rewriter) const override {
287285 auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType ());
288286 if (!oldType || !oldType.getLayout ().isIdentity () || oldType.getRank () <= 1 )
289287 return failure ();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
314312 memref::ExtractStridedMetadataOp::create (rewriter, loc, op.getSrc ());
315313
316314 SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes ();
317- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides ();
315+ SmallVector<OpFoldResult> origStrides =
316+ metadata.getConstifiedMixedStrides ();
318317 OpFoldResult offset = metadata.getConstifiedMixedOffset ();
319318
320319 SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
338337 }
339338};
340339
341- struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
340+ struct FlattenExpandShape final
341+ : public OpRewritePattern<memref::ExpandShapeOp> {
342342 using OpRewritePattern::OpRewritePattern;
343343
344344 LogicalResult matchAndRewrite (memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
348348 memref::ExtractStridedMetadataOp::create (rewriter, loc, op.getSrc ());
349349
350350 SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes ();
351- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides ();
351+ SmallVector<OpFoldResult> origStrides =
352+ metadata.getConstifiedMixedStrides ();
352353 OpFoldResult offset = metadata.getConstifiedMixedOffset ();
353354
354355 SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
372373 }
373374};
374375
375-
376376// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
377377struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
378378 using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
405405 for (OpFoldResult ofr : mixedOffsets)
406406 offsetValues.push_back (getValueFromOpFoldResult (rewriter, loc, ofr));
407407
408- auto [flatSource, linearOffset] =
409- getFlattenMemrefAndOffset (rewriter, loc, op.getSource (),
410- ValueRange (offsetValues));
408+ auto [flatSource, linearOffset] = getFlattenMemrefAndOffset (
409+ rewriter, loc, op.getSource (), ValueRange (offsetValues));
411410
412411 memref::ExtractStridedMetadataOp sourceMetadata =
413412 memref::ExtractStridedMetadataOp::create (rewriter, loc, op.getSource ());
@@ -424,9 +423,14 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
424423 resultStrides.reserve (resultType.getRank ());
425424
426425 OpFoldResult resultOffset = sourceOffset;
427- for (auto [idx, it] : llvm::enumerate (llvm::zip_equal (
426+ for (auto zipped : llvm::enumerate (llvm::zip_equal (
428427 mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
429- auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
428+ auto idx = zipped.index ();
429+ auto it = zipped.value ();
430+ auto offsetOfr = std::get<0 >(it);
431+ auto strideOfr = std::get<1 >(it);
432+ auto sizeOfr = std::get<2 >(it);
433+ auto relativeStrideOfr = std::get<3 >(it);
430434 OpFoldResult contribution = [&]() -> OpFoldResult {
431435 if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
432436 if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
@@ -449,7 +453,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
449453 }
450454 }
451455 Value offsetVal = getValueFromOpFoldResult (rewriter, loc, resultOffset);
452- Value contribVal = getValueFromOpFoldResult (rewriter, loc, contribution);
456+ Value contribVal =
457+ getValueFromOpFoldResult (rewriter, loc, contribution);
453458 return rewriter.create <arith::AddIOp>(loc, offsetVal, contribVal)
454459 .getResult ();
455460 }();
@@ -478,12 +483,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
478483 memref::LinearizedMemRefInfo linearizedInfo;
479484 [[maybe_unused]] OpFoldResult linearizedIndex;
480485 std::tie (linearizedInfo, linearizedIndex) =
481- memref::getLinearizedMemRefOffsetAndSize (
482- rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
483- resultSizes, resultStrides);
486+ memref::getLinearizedMemRefOffsetAndSize (rewriter, loc, elementBitWidth,
487+ elementBitWidth, resultOffset,
488+ resultSizes, resultStrides);
484489
485- Value flattenedSize = getValueFromOpFoldResult (
486- rewriter, loc, linearizedInfo.linearizedSize );
490+ Value flattenedSize =
491+ getValueFromOpFoldResult ( rewriter, loc, linearizedInfo.linearizedSize );
487492 Value strideOne = arith::ConstantIndexOp::create (rewriter, loc, 1 );
488493
489494 Value flattenedSubview = memref::SubViewOp::create (
@@ -524,10 +529,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
524529 using OpRewritePattern::OpRewritePattern;
525530
526531 LogicalResult matchAndRewrite (memref::GetGlobalOp op,
527- PatternRewriter &rewriter) const override {
532+ PatternRewriter &rewriter) const override {
528533 // Check if this get_global references a multi-dimensional global
529534 auto module = op->template getParentOfType <ModuleOp>();
530- auto globalOp = module .template lookupSymbol <memref::GlobalOp>(op.getName ());
535+ auto globalOp =
536+ module .template lookupSymbol <memref::GlobalOp>(op.getName ());
531537 if (!globalOp) {
532538 return failure ();
533539 }
@@ -537,12 +543,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
537543
538544 // Only apply if the global has been flattened but the get_global hasn't
539545 if (globalType.getRank () == 1 && resultType.getRank () > 1 ) {
540- auto newGetGlobal = memref::GetGlobalOp::create (
541- rewriter, op. getLoc (), globalType, op.getName ());
546+ auto newGetGlobal = memref::GetGlobalOp::create (rewriter, op. getLoc (),
547+ globalType, op.getName ());
542548
543549 // Cast the flattened result back to the original shape
544550 memref::ExtractStridedMetadataOp stridedMetadata =
545- memref::ExtractStridedMetadataOp::create (rewriter, op.getLoc (), op.getResult ());
551+ memref::ExtractStridedMetadataOp::create (rewriter, op.getLoc (),
552+ op.getResult ());
546553 auto castResult = memref::ReinterpretCastOp::create (
547554 rewriter, op.getLoc (), resultType, newGetGlobal,
548555 /* offset=*/ rewriter.getIndexAttr (0 ),
@@ -572,13 +579,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
572579 MemRefRewritePattern<memref::StoreOp>,
573580 MemRefRewritePattern<memref::AllocOp>,
574581 MemRefRewritePattern<memref::AllocaOp>,
575- MemRefRewritePattern<memref::DeallocOp>,
576- FlattenExpandShape,
577- FlattenCollapseShape,
578- FlattenSubView,
579- FlattenGetGlobal,
580- FlattenGlobal>(
581- patterns.getContext ());
582+ MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
583+ FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
584+ FlattenGlobal>(patterns.getContext ());
582585}
583586
584587void memref::populateFlattenMemrefsPatterns (RewritePatternSet &patterns) {
0 commit comments