Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
class Location;
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
Expand All @@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
namespace memref {
class AllocOp;
class AllocaOp;
class CollapseShapeOp;
class DeallocOp;
class ExpandShapeOp;

//===----------------------------------------------------------------------===//
// Patterns
Expand Down Expand Up @@ -213,6 +220,100 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
/// baseSizes#groupId / product(expandShapeSizes#j,
/// for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
unsigned groupId);

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
/// reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index for at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
/// the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
/// element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
unsigned groupId);

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
Location loc, ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
llvm::function_ref<bool(int64_t)> isDynamic);

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
unsigned groupId);

/// Compute the collapsed stride of the given \p collpaseShape for the
Copy link

Copilot AI Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in parameter name: 'collpaseShape' should be 'collapseShape'.

Suggested change
/// Compute the collapsed stride of the given \p collpaseShape for the
/// Compute the collapsed stride of the given \p collapseShape for the

Copilot uses AI. Check for mistakes.
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the inner most
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
unsigned groupId);

} // namespace memref
} // namespace mlir

Expand Down
90 changes: 28 additions & 62 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
Expand All @@ -35,6 +36,7 @@ namespace memref {

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::memref;

namespace {

Expand Down Expand Up @@ -250,24 +252,14 @@ struct ExtractStridedMetadataOpSubviewFolder
}
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
/// baseSizes#groupId / product(expandShapeSizes#j,
/// for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
} // namespace

namespace mlir {
namespace memref {
SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
Expand Down Expand Up @@ -305,31 +297,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
/// reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index for at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
/// the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
/// element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
Expand Down Expand Up @@ -405,18 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
llvm::function_ref<bool(int64_t)> isDynamic) {
OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
Location loc, ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
llvm::function_ref<bool(int64_t)> isDynamic) {
AffineExpr productOfValues = builder.getAffineConstantExpr(1);
SmallVector<OpFoldResult> inputValues;
unsigned numberOfSymbols = 0;
Expand Down Expand Up @@ -450,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;

MemRefType collapseShapeType = collapseShape.getResultType();
Expand Down Expand Up @@ -491,10 +452,11 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
Expand Down Expand Up @@ -546,6 +508,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,

return {lastValidStride};
}
} // namespace memref
} // namespace mlir

namespace {

/// From `reshape_like(memref, subSizes, subStrides))` compute
///
Expand Down
Loading