
Commit f549e4f

New things.

1 parent bc742d1 commit f549e4f

2 files changed: +259 −1 lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 211 additions & 1 deletion
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -95,6 +96,99 @@ static bool checkLayout(Value val) {
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
+/// Produce an OpFoldResult representing the product of the values or constants
+/// referenced by `indices`. `staticShape` provides the statically known sizes
+/// for the source memref, while `values` contains the mixed (value/attribute)
+/// representation produced by `memref.extract_strided_metadata`.
+static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices,
+                                       OpBuilder &builder, Location loc,
+                                       ArrayRef<int64_t> staticShape,
+                                       ArrayRef<OpFoldResult> values) {
+  AffineExpr product = builder.getAffineConstantExpr(1);
+  SmallVector<OpFoldResult> inputs;
+  unsigned numSymbols = 0;
+  for (int64_t idx : indices) {
+    product = product * builder.getAffineSymbolExpr(numSymbols++);
+    if (ShapedType::isDynamic(staticShape[idx]))
+      inputs.push_back(values[idx]);
+    else
+      inputs.push_back(builder.getIndexAttr(staticShape[idx]));
+  }
+  return affine::makeComposedFoldedAffineApply(builder, loc, product, inputs);
+}
+
+/// Return the collapsed size (as OpFoldResult) for the reassociation group
+/// `groupId` of `collapseShapeOp`.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<OpFoldResult> collapsedSize;
+
+  MemRefType resultType = collapseShapeOp.getResultType();
+  int64_t dimSize = resultType.getDimSize(groupId);
+  if (!ShapedType::isDynamic(dimSize)) {
+    collapsedSize.push_back(builder.getIndexAttr(dimSize));
+    return collapsedSize;
+  }
+
+  auto sourceType = collapseShapeOp.getSrcType();
+  ArrayRef<int64_t> staticShape = sourceType.getShape();
+  ArrayRef<int64_t> reassocGroup =
+      collapseShapeOp.getReassociationIndices()[groupId];
+
+  collapsedSize.push_back(getProductOfValues(
+      reassocGroup, builder, collapseShapeOp.getLoc(), staticShape, origSizes));
+  return collapsedSize;
+}
+
+/// Return the collapsed stride (as OpFoldResult) for the reassociation group
+/// `groupId` of `collapseShapeOp`.
+static SmallVector<OpFoldResult> getCollapsedStride(
+    memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+    ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides,
+    unsigned groupId) {
+  ArrayRef<int64_t> reassocGroup =
+      collapseShapeOp.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "reassociation group must contain at least one dimension");
+
+  auto sourceType = collapseShapeOp.getSrcType();
+  auto [strides, offset] = sourceType.getStridesAndOffset();
+  (void)offset;
+  ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+  OpFoldResult lastValidStride = nullptr;
+  for (int64_t dim : reassocGroup) {
+    if (srcShape[dim] == 1)
+      continue;
+    int64_t currentStride = strides[dim];
+    if (ShapedType::isDynamic(currentStride))
+      lastValidStride = origStrides[dim];
+    else
+      lastValidStride = builder.getIndexAttr(currentStride);
+  }
+
+  if (!lastValidStride) {
+    MemRefType collapsedType = collapseShapeOp.getResultType();
+    auto [collapsedStrides, collapsedOffset] =
+        collapsedType.getStridesAndOffset();
+    (void)collapsedOffset;
+    int64_t finalStride = collapsedStrides[groupId];
+    if (ShapedType::isDynamic(finalStride)) {
+      for (int64_t dim : reassocGroup) {
+        assert(srcShape[dim] == 1 && "expected size-one dimensions");
+        if (ShapedType::isDynamic(strides[dim]))
+          return {origStrides[dim]};
+      }
+      llvm_unreachable("expected to find a dynamic stride");
+    }
+    return {builder.getIndexAttr(finalStride)};
+  }
+
+  return {lastValidStride};
+}
+
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
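
Note: taken together, these helpers linearize each reassociation group. The collapsed size is the product of the member sizes (static sizes are passed as index attributes and folded away, dynamic ones stay as symbols of the composed affine apply), and the collapsed stride is taken from the innermost non-unit member dimension. A minimal sketch of the IR this yields for group [0, 1] of a 2x?x4 source, with value names invented for illustration:

  // Recover sizes/strides of the source, then fold 2 * %sizes#1 into a
  // single affine.apply; makeComposedFoldedAffineApply composes the static
  // factor 2 into the map as a constant.
  %base, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %src
      : memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>
      -> memref<f32>, index, index, index, index, index, index, index
  %group_size = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%sizes#1]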
@@ -256,6 +350,82 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
+/// Flattens memref global ops with more than one dimension into one dimension.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value)
+      return value;
+    if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value)) {
+      return splatAttr.reshape(newType);
+    } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+      return denseAttr.reshape(newType);
+    } else if (auto denseResourceAttr =
+                   llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    }
+    return {};
+  }
+
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp globalOp,
+                  PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+      return failure();
+
+    auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+                                            oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+                        AffineMap(), oldType.getMemorySpace());
+    auto newInitialValue =
+        flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInitialValue, globalOp.getConstant(),
+        /*alignment=*/IntegerAttr());
+    return success();
+  }
+};
+
+struct FlattenCollapseShape final
+    : public OpRewritePattern<memref::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    memref::ExtractStridedMetadataOp metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+    SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+    OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+    SmallVector<OpFoldResult> collapsedSizes;
+    SmallVector<OpFoldResult> collapsedStrides;
+    unsigned numGroups = op.getReassociationIndices().size();
+    collapsedSizes.reserve(numGroups);
+    collapsedStrides.reserve(numGroups);
+    for (unsigned i = 0; i < numGroups; ++i) {
+      SmallVector<OpFoldResult> groupSizes =
+          getCollapsedSize(op, rewriter, origSizes, i);
+      SmallVector<OpFoldResult> groupStrides =
+          getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+      collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+      collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), op.getSrc(), offset, collapsedSizes,
+        collapsedStrides);
+    return success();
+  }
+};
+
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
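
Note: the DenseResourceElementsAttr branch of flattenAttribute matters because resource-backed initializers are not reshaped through DenseElementsAttr::reshape; the pattern instead rebuilds the attribute around the same raw resource handle with the flattened type. A hypothetical example (symbol and resource names invented; the tests below only cover the dense<...> case):

  // Before: resource-backed multi-dimensional global.
  memref.global "private" constant @weights : memref<2x2xf32> = dense_resource<weights_data>
  // After FlattenGlobal (sketch): same handle, rank-1 type.
  // memref.global "private" constant @weights : memref<4xf32> = dense_resource<weights_data>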
@@ -288,12 +458,52 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
       patterns.getContext());
 }
 
+/// Special pattern for GetGlobalOp to avoid infinite loops.
+struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::GetGlobalOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check if this get_global references a multi-dimensional global.
+    auto module = op->getParentOfType<ModuleOp>();
+    auto globalOp = module.lookupSymbol<memref::GlobalOp>(op.getName());
+    if (!globalOp) {
+      return failure();
+    }
+
+    auto globalType = globalOp.getType();
+    auto resultType = op.getType();
+
+    // Only apply if the global has been flattened but the get_global hasn't.
+    if (globalType.getRank() == 1 && resultType.getRank() > 1) {
+      auto newGetGlobal = memref::GetGlobalOp::create(
+          rewriter, op.getLoc(), globalType, op.getName());
+
+      // Cast the flattened result back to the original shape.
+      memref::ExtractStridedMetadataOp stridedMetadata =
+          memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
+                                                   op.getResult());
+      auto castResult = memref::ReinterpretCastOp::create(
+          rewriter, op.getLoc(), resultType, newGetGlobal,
+          /*offset=*/rewriter.getIndexAttr(0),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides());
+      rewriter.replaceOp(op, castResult);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<memref::DeallocOp>>(
+                  MemRefRewritePattern<memref::DeallocOp>,
+                  FlattenCollapseShape, FlattenGetGlobal, FlattenGlobal>(
       patterns.getContext());
 }
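
Note: FlattenGetGlobal complements FlattenGlobal. Once the global symbol has been rewritten to rank 1, any memref.get_global still producing the old multi-dimensional type (transiently inconsistent IR until this pattern fires) is redirected to the flat buffer and reinterpret_cast back to its original shape, which the other flattening patterns can then consume. A sketch of that intermediate rewrite, with invented names, assuming @g was 3x4 before flattening:

  // Stale user of an already-flattened global @g : memref<12xf32>.
  %0 = memref.get_global @g : memref<3x4xf32>
  // Rewritten to (sketch):
  %1 = memref.get_global @g : memref<12xf32>
  %2 = memref.reinterpret_cast %1 to offset: [0], sizes: [3, 4], strides: [4, 1]
      : memref<12xf32> to memref<3x4xf32>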

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 48 additions & 0 deletions
@@ -194,6 +194,35 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
 
 // -----
 
+func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<2x3x4xf32> into memref<6x4xf32>
+  return %0 : memref<6x4xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [6, 4], strides: [4, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+func.func @collapse_shape_dynamic(
+    %arg0: memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>) ->
+    memref<?x4xf32, strided<[?, ?], offset: ?>> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>
+      into memref<?x4xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
+}
+// CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @collapse_shape_dynamic
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
+// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 4], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: return %[[REINT]]
+
+// -----
+
 func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
   %c0 = arith.constant 0 : i2
   %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
@@ -336,3 +365,22 @@ func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
 // CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+memref.global "private" constant @constant_3x3x1x1xf32 : memref<3x3x1x1xf32> = dense<[[[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]], [[[-2.000000e+00]], [[0.000000e+00]], [[2.000000e+00]]], [[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]]]>
+func.func @load_global_with_offset(%i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+  %global = memref.get_global @constant_3x3x1x1xf32 : memref<3x3x1x1xf32>
+  %val = memref.load %global[%i0, %i1, %i2, %i3] : memref<3x3x1x1xf32>
+  return %val : f32
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 * 3 + s1 + s2 + s3)>
+// CHECK: memref.global "private" constant @constant_3x3x1x1xf32 : memref<9xf32> = dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, -2.000000e+00, 0.000000e+00, 2.000000e+00, -1.000000e+00, 0.000000e+00, 1.000000e+00]>
+// CHECK-LABEL: func.func @load_global_with_offset
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+// CHECK: %[[GLOBAL:.+]] = memref.get_global @constant_3x3x1x1xf32 : memref<9xf32>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[$MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]]]
+// CHECK: %[[REINTERPRET:.+]] = memref.reinterpret_cast %[[GLOBAL]] to offset: [0], sizes: [9], strides: [1] : memref<9xf32> to memref<9xf32, strided<[1]>>
+// CHECK: %[[LOAD:.+]] = memref.load %[[REINTERPRET]][%[[INDEX]]] : memref<9xf32, strided<[1]>>
+// CHECK: return %[[LOAD]]
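
Note on the map checked above: for the 3x3x1x1 shape the row-major strides are [3, 1, 1, 1], so the linearized index is i0*3 + i1 + i2 + i3, which is exactly the #[[$MAP]] applied before the flattened load; the two trailing size-one dimensions contribute indices that are always zero for in-bounds accesses.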
