Skip to content

Commit 79d2edf

Browse files
committed
ExpandShapeOp
1 parent f549e4f commit 79d2edf

File tree

2 files changed

+174
-1
lines changed

2 files changed

+174
-1
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,115 @@ static bool checkLayout(Value val) {
9696
isa<StridedLayoutAttr>(type.getLayout());
9797
}
9898

99+
/// Compute the expanded sizes of the given expand_shape for the reassociation
100+
/// group `groupId`. Portions adapted from
101+
/// `lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp` to avoid a direct
102+
/// dependency from the MemRef dialect on the Affine dialect.
103+
static SmallVector<OpFoldResult>
104+
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
105+
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
106+
ArrayRef<int64_t> reassocGroup =
107+
expandShape.getReassociationIndices()[groupId];
108+
assert(!reassocGroup.empty() &&
109+
"Reassociation group should have at least one dimension");
110+
111+
unsigned groupSize = reassocGroup.size();
112+
SmallVector<OpFoldResult> expandedSizes(groupSize);
113+
114+
uint64_t productOfAllStaticSizes = 1;
115+
std::optional<unsigned> dynSizeIdx;
116+
MemRefType expandShapeType = expandShape.getResultType();
117+
118+
for (unsigned i = 0; i < groupSize; ++i) {
119+
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
120+
if (ShapedType::isDynamic(dimSize)) {
121+
assert(!dynSizeIdx && "there must be at most one dynamic size per group");
122+
dynSizeIdx = i;
123+
continue;
124+
}
125+
productOfAllStaticSizes *= dimSize;
126+
expandedSizes[i] = builder.getIndexAttr(dimSize);
127+
}
128+
129+
if (dynSizeIdx) {
130+
AffineExpr s0 = builder.getAffineSymbolExpr(0);
131+
expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
132+
builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
133+
origSizes[groupId]);
134+
}
135+
136+
return expandedSizes;
137+
}
138+
139+
/// Compute the expanded strides of the given expand_shape for the reassociation
140+
/// group `groupId`.
141+
static SmallVector<OpFoldResult>
142+
getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
143+
ArrayRef<OpFoldResult> origSizes,
144+
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
145+
ArrayRef<int64_t> reassocGroup =
146+
expandShape.getReassociationIndices()[groupId];
147+
assert(!reassocGroup.empty() &&
148+
"Reassociation group should have at least one dimension");
149+
150+
unsigned groupSize = reassocGroup.size();
151+
MemRefType expandShapeType = expandShape.getResultType();
152+
153+
std::optional<int64_t> dynSizeIdx;
154+
uint64_t currentStride = 1;
155+
SmallVector<OpFoldResult> expandedStrides(groupSize);
156+
for (int i = groupSize - 1; i >= 0; --i) {
157+
expandedStrides[i] = builder.getIndexAttr(currentStride);
158+
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
159+
if (ShapedType::isDynamic(dimSize)) {
160+
assert(!dynSizeIdx && "there must be at most one dynamic size per group");
161+
dynSizeIdx = i;
162+
continue;
163+
}
164+
currentStride *= dimSize;
165+
}
166+
167+
auto sourceType = expandShape.getSrcType();
168+
auto [strides, offset] = sourceType.getStridesAndOffset();
169+
(void)offset;
170+
171+
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
172+
? origStrides[groupId]
173+
: builder.getIndexAttr(strides[groupId]);
174+
175+
int64_t doneStrideIdx = 0;
176+
if (dynSizeIdx) {
177+
int64_t productOfAllStaticSizes = currentStride;
178+
assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
179+
"dynamic reassociation must originate from dynamic source dim");
180+
OpFoldResult origSize = origSizes[groupId];
181+
182+
AffineExpr s0 = builder.getAffineSymbolExpr(0);
183+
AffineExpr s1 = builder.getAffineSymbolExpr(1);
184+
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
185+
auto baseAttr = expandedStrides[doneStrideIdx].dyn_cast<Attribute>();
186+
assert(baseAttr && "expected attribute stride");
187+
int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
188+
expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
189+
builder, expandShape.getLoc(),
190+
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
191+
{origSize, origStride});
192+
}
193+
}
194+
195+
AffineExpr s0 = builder.getAffineSymbolExpr(0);
196+
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
197+
auto baseAttr = expandedStrides[doneStrideIdx].dyn_cast<Attribute>();
198+
assert(baseAttr && "expected attribute stride");
199+
int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
200+
expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
201+
builder, expandShape.getLoc(), s0 * baseExpandedStride,
202+
{origStride});
203+
}
204+
205+
return expandedStrides;
206+
}
207+
99208
/// Produce an OpFoldResult representing the product of the values or constants
100209
/// referenced by `indices`. `staticShape` provides the statically known sizes
101210
/// for the source memref, while `values` contains the mixed (value/attribute)
@@ -426,6 +535,40 @@ struct FlattenCollapseShape final
426535
}
427536
};
428537

538+
/// Rewrites memref.expand_shape into an equivalent memref.reinterpret_cast by
/// materializing the expanded sizes/strides from the source's strided
/// metadata, one reassociation group at a time.
struct FlattenExpandShape final
    : public OpRewritePattern<memref::ExpandShapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Recover the source's offset/sizes/strides in constified mixed form.
    memref::ExtractStridedMetadataOp metadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
    SmallVector<OpFoldResult> srcSizes = metadata.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> srcStrides =
        metadata.getConstifiedMixedStrides();
    OpFoldResult srcOffset = metadata.getConstifiedMixedOffset();

    SmallVector<OpFoldResult> resultSizes;
    SmallVector<OpFoldResult> resultStrides;
    resultSizes.reserve(op.getResultType().getRank());
    resultStrides.reserve(op.getResultType().getRank());

    // Concatenate the per-group expanded sizes/strides in group order.
    unsigned numGroups = op.getReassociationIndices().size();
    for (unsigned group = 0; group != numGroups; ++group) {
      SmallVector<OpFoldResult> sizes =
          getExpandedSizes(op, rewriter, srcSizes, group);
      resultSizes.append(sizes.begin(), sizes.end());
      SmallVector<OpFoldResult> strides =
          getExpandedStrides(op, rewriter, srcSizes, srcStrides, group);
      resultStrides.append(strides.begin(), strides.end());
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, op.getType(), op.getSrc(), srcOffset, resultSizes, resultStrides);
    return success();
  }
};
571+
429572
struct FlattenMemrefsPass
430573
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
431574
using Base::Base;
@@ -501,6 +644,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
501644
MemRefRewritePattern<memref::AllocOp>,
502645
MemRefRewritePattern<memref::AllocaOp>,
503646
MemRefRewritePattern<memref::DeallocOp>,
647+
FlattenExpandShape,
504648
FlattenCollapseShape,
505649
FlattenGetGlobal,
506650
FlattenGlobal>(

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ func.func @collapse_shape_dynamic(
214214
return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
215215
}
216216
// CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
217-
// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
218217
// CHECK-LABEL: func @collapse_shape_dynamic
219218
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
220219
// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
@@ -223,6 +222,36 @@ func.func @collapse_shape_dynamic(
223222

224223
// -----
225224

225+
// expand_shape with fully static shapes lowers to a single
// memref.reinterpret_cast with constant offset, sizes, and strides.
func.func @expand_shape_static(%arg0: memref<6x4xf32>) -> memref<2x3x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 4]
      : memref<6x4xf32> into memref<2x3x4xf32>
  return %0 : memref<2x3x4xf32>
}
// CHECK-LABEL: func @expand_shape_static
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1]
// CHECK: return %[[REINT]]
233+
234+
// -----
235+
236+
// Dynamic outer dimension: the expanded size is recovered as the original
// size floordiv 3 and the outer stride as the original stride * 3, both via
// affine.apply on the extracted strided metadata.
func.func @expand_shape_dynamic(
    %arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>, %size: index) ->
    memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%size, 3, 4]
      : memref<?x4xf32, strided<[?, ?], offset: ?>>
      into memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
  return %0 : memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
}
// CHECK: #map = affine_map<()[s0] -> (s0 floordiv 3)>
// CHECK: #map1 = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @expand_shape_dynamic
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %arg0
// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#0]
// CHECK: %[[STRIDE:.*]] = affine.apply #map1()[%[[STRIDES]]#0]
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 3, 4], strides: [%[[STRIDE]], %[[STRIDES]]#0, %[[STRIDES]]#1]
// CHECK: return %[[REINT]]
252+
253+
// -----
254+
226255
func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
227256
%c0 = arith.constant 0 : i2
228257
%0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>

0 commit comments

Comments
 (0)