[MLIR] Add more ops support for flattening memref operands #159841
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

Changes: This patch makes the `flatten-memref` pass more complete by adding support for the memref operands of more ops.

Some patterns are from: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp

Patch is 34.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159841.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..562b8c11225e8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h"
namespace mlir {
+class Location;
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
namespace memref {
class AllocOp;
class AllocaOp;
+class CollapseShapeOp;
class DeallocOp;
+class ExpandShapeOp;
//===----------------------------------------------------------------------===//
// Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+/// baseSizes#groupId / product(expandShapeSizes#j,
+/// for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+/// origStrides#reassDim * product(expandShapeSizes#j, for j in
+/// reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index for at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+/// the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+/// element.)
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
+/// values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collpaseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the inner most
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapse to one
+/// dimension.
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..6b69d0e366903 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -24,6 +24,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
using namespace mlir;
using namespace mlir::affine;
+using namespace mlir::memref;
namespace {
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
}
};
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-/// baseSizes#groupId / product(expandShapeSizes#j,
-/// for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
return expandedSizes;
}
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-/// origStrides#reassDim * product(expandShapeSizes#j, for j in
-/// reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-/// the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-/// element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
return expandedStrides;
}
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
+OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
-static SmallVector<OpFoldResult>
-getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {lastValidStride};
}
+} // namespace memref
+} // namespace mlir
+
+namespace {
/// From `reshape_like(memref, subSizes, subStrides))` compute
///
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..43a67f1fab2be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,11 +21,13 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
@@ -46,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
return cast<Value>(in);
}
+
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -89,17 +92,21 @@ static bool needFlattening(Value val) {
return type.getRank() > 1;
}
-static bool checkLayout(Value val) {
- auto type = cast<MemRefType>(val.getType());
+static bool checkLayout(MemRefType type) {
return type.getLayout().isIdentity() ||
isa<StridedLayoutAttr>(type.getLayout());
}
+static bool checkLayout(Value val) {
+ return checkLayout(cast<MemRefType>(val.getType()));
+}
+
namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
- memref::AllocOp>([](auto op) { return op.getMemref(); })
+ memref::AllocOp, memref::DeallocOp>(
+ [](auto op) { return op.getMemref(); })
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
@@ -189,6 +196,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
})
+ .template Case<memref::DeallocOp>([&](auto op) {
+ auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
+ rewriter.replaceOp(op, newDealloc);
+ })
.Default([&](auto op) {
op->emitOpError("unimplemented: do not know how to replace op.");
});
@@ -197,7 +208,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
template <typename T>
static ValueRange getIndices(T op) {
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
- std::is_same_v<T, memref::AllocOp>) {
+ std::is_same_v<T, memref::AllocOp> ||
+ std::is_same_v<T, memref::DeallocOp>) {
return ValueRange{};
} else {
return op.getIndices();
@@ -250,6 +262,243 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
}
};
+/// Flattens memref global ops with more than 1 dimensions to 1 dimension.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+ if (!value)
+ return value;
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value)) {
+ return splatAttr.reshape(newType);
+ } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+ return denseAttr.reshape(newType);
+ } else if (auto denseResourceAttr =
+ llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+ return DenseResourceElementsAttr::get(newType,
+ denseResourceAttr.getRawHandle());
+ }
+ return {};
+ }
+
+ LogicalResult
+ matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
+ auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+ if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+ return failure();
+
+ auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+ oldType.getElementType());
+ auto memRefType =
+ MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+ AffineMap(), oldType.getMemorySpace());
+ auto newInitialValue =
+ flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+ rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+ memRefType, newInitialValue, globalOp.getConstant(),
+ /*alignment=*/IntegerAttr());
+ return success();
+ }
+};
+
+struct FlattenCollapseShape final
+ : public OpRewritePattern<memref::CollapseShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> collapsedSizes;
+ SmallVector<OpFoldResult> collapsedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ collapsedSizes.reserve(numGroups);
+ collapsedStrides.reserve(numGroups);
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getCollapsedSize(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+ collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, collapsedSizes,
+ collapsedStrides);
+ return success();
+ }
+};
+
+struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ expandedSizes.reserve(op.getResultType().getRank());
+ expandedStrides.reserve(op.getResultType().getRank());
+
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getExpandedSizes(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+ expandedSizes.append(groupSizes.begin(), groupSizes.end());
+ expandedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+ return success();
+ }
+};
+
+
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+ if (!sourceType || sourceType.getRank() <= 1)
+ return failure();
+ if (!checkLayout(sourceType))
+ return failure();
+
+ MemRefType resultType = op.getType();
+ if (resultType.getRank() <= 1 || !checkLayout(res...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
Pull Request Overview
This PR extends the flatten-memref pass in MLIR to support additional operations by adding flattening patterns for more memref operations. The pass transforms multi-dimensional memref operations into equivalent 1D operations for optimization purposes.
Key changes:
- Added support for `memref.dealloc`, `memref.subview`, `memref.collapse_shape`, `memref.expand_shape`, and `memref.global` operations
- Moved utility functions from ExpandStridedMetadata.cpp to the public API to enable reuse
- Added comprehensive test coverage for all new operation patterns
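For intuition about what "flattening" means here, the pass replaces a rank-N memref access with a single 1-D access whose offset is the sum of each index times its stride. A minimal sketch of that arithmetic (a hypothetical helper for illustration, not the pass's actual code):

```python
def linearize(indices, strides):
    """Collapse a multi-dimensional access into a 1-D offset.

    Mirrors the affine arithmetic a flattening rewrite emits:
    offset = sum(index_i * stride_i).
    """
    assert len(indices) == len(strides)
    return sum(i * s for i, s in zip(indices, strides))

# A memref<4x8xf32> with the identity layout has row-major strides (8, 1),
# so a load at [2, 3] becomes a 1-D load at offset 2*8 + 3 = 19.
print(linearize((2, 3), (8, 1)))  # 19
```

The pass builds this sum with affine ops on `OpFoldResult`s so it also handles dynamic sizes and strides.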
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | Implemented new flattening patterns for dealloc, subview, collapse/expand shape, and global operations |
| mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | Moved utility functions to public API and updated namespace declarations |
| mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | Added public declarations for moved utility functions |
| mlir/test/Dialect/MemRef/flatten_memref.mlir | Added comprehensive test cases for all new flattening patterns |
```cpp
    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
        memRefType, newInitialValue, globalOp.getConstant(),
        /*alignment=*/IntegerAttr());
```
Copilot AI · Sep 30, 2025
Passing an empty IntegerAttr() for alignment may not preserve the original global's alignment attribute. Consider preserving the original alignment with globalOp.getAlignment() instead.
```diff
-        /*alignment=*/IntegerAttr());
+        /*alignment=*/globalOp.getAlignment());
```
The bot's got a point
```cpp
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);

/// Compute the collapsed stride of the given \p collpaseShape for the
```
Copilot AI · Sep 30, 2025
Typo in parameter name: 'collpaseShape' should be 'collapseShape'.
```diff
-/// Compute the collapsed stride of the given \p collpaseShape for the
+/// Compute the collapsed stride of the given \p collapseShape for the
```
krzysz00 left a comment:
I'm putting a hold on the expand_shape/collapse_shape/subview parts of this code until we've got a clearer understanding of why it's like that and why it's justified (and how it's "flattening")
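For reference while reviewing the expand_shape part: the header comment for `getExpandedStrides` says each expanded dimension's stride is the group's original stride times the product of the expanded sizes that come after it within the group. A static-shape-only sketch of that formula (hypothetical names; the real code also handles the dynamic-size case with affine expressions):

```python
from math import prod

def expanded_strides(group_sizes, orig_stride):
    """Strides for one reassociation group of an expand_shape.

    strides#i = orig_stride * product(group_sizes[i+1:]),
    matching the formula in the Transforms.h doc comment
    (static sizes only).
    """
    return [orig_stride * prod(group_sizes[i + 1:])
            for i in range(len(group_sizes))]

# Expanding a contiguous memref<6xf32> (stride 1) into memref<2x3xf32>
# yields strides [3, 1].
print(expanded_strides([2, 3], 1))  # [3, 1]
```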
```cpp
    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, op.getType(), op.getSrc(), offset, collapsedSizes,
        collapsedStrides);
    return success();
```
Why are you rewriting collapse_shape to a reinterpret_cast? Shouldn't it be a noop after flattening?
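For context on what the reinterpret_cast's operands encode: per reassociation group, the collapsed size is the product of the member sizes, and for a contiguous group the collapsed stride is the stride of the group's innermost dimension. A static-shape sketch of that relationship (hypothetical helper, not the pass's code):

```python
from math import prod

def collapse_group(sizes, strides, group):
    """Collapsed (size, stride) for one reassociation group.

    size   = product of the group's sizes;
    stride = stride of the innermost (last) dimension in the group.
    Static, contiguous shapes only.
    """
    return prod(sizes[d] for d in group), strides[group[-1]]

# memref<2x3x4xf32> with strides (12, 4, 1) and reassociation [[0, 1], [2]]
# collapses to sizes [6, 4] with strides [4, 1].
sizes, strides = (2, 3, 4), (12, 4, 1)
groups = [[0, 1], [2]]
print([collapse_group(sizes, strides, g) for g in groups])  # [(6, 4), (4, 1)]
```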
```cpp
        resultStrides);

    rewriter.replaceOp(op, replacement);
    return success();
```
... Something here feels extremely dubious. Maybe it'll make more sense when I look at the tests
I'd say that this all should be run after fold-memref-alias-ops and so expand_shape, collapse_shape, and subview shouldn't exist at the time this code is run.
Also, general notes: