@@ -96,6 +96,115 @@ static bool checkLayout(Value val) {
          isa<StridedLayoutAttr>(type.getLayout());
 }

+/// Compute the expanded sizes of the given expand_shape for the reassociation
+/// group `groupId`. Portions adapted from
+/// `lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp` to avoid a direct
+/// dependency of the MemRef dialect on the Affine dialect.
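+/// For example (illustrative shapes): expanding a dynamic source dimension of
+/// size %n into result dimensions [?, 4, 2] yields the group sizes
+/// [%n floordiv 8, 4, 2], since the product of the group's static sizes is 8.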
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandShape.getReassociationIndices();
+  ArrayRef<int64_t> reassocGroup = reassocIndices[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+  uint64_t productOfAllStaticSizes = 1;
+  std::optional<unsigned> dynSizeIdx;
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  for (unsigned i = 0; i < groupSize; ++i) {
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx &&
+             "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    productOfAllStaticSizes *= dimSize;
+    expandedSizes[i] = builder.getIndexAttr(dimSize);
+  }
+
+  if (dynSizeIdx) {
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+        origSizes[groupId]);
+  }
+
+  return expandedSizes;
+}
+
+/// Compute the expanded strides of the given expand_shape for the
+/// reassociation group `groupId`.
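+/// For example (illustrative): for a group with result dimensions [2, ?, 4]
+/// coming from a source dimension of dynamic size %n with stride %s, the
+/// expanded strides are [((%n * 4) floordiv 8) * %s, 4 * %s, 1 * %s], where 8
+/// is the product of the group's static sizes.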
+static SmallVector<OpFoldResult>
+getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandShape.getReassociationIndices();
+  ArrayRef<int64_t> reassocGroup = reassocIndices[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  std::optional<int64_t> dynSizeIdx;
+  uint64_t currentStride = 1;
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = builder.getIndexAttr(currentStride);
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx &&
+             "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    currentStride *= dimSize;
+  }
+
+  auto sourceType = expandShape.getSrcType();
+  auto [strides, offset] = sourceType.getStridesAndOffset();
+  (void)offset;
+
+  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+                                ? origStrides[groupId]
+                                : builder.getIndexAttr(strides[groupId]);
+
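+  // Strides of dimensions to the left of the dynamic one must also account
+  // for the dynamic expanded size: their stride becomes
+  // (origSize * baseStride) floordiv productOfAllStaticSizes, scaled by the
+  // group's original stride. Dimensions at or to the right of the dynamic one
+  // only need to be scaled by the original stride.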
+  int64_t doneStrideIdx = 0;
+  if (dynSizeIdx) {
+    int64_t productOfAllStaticSizes = currentStride;
+    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+           "Dynamic reassociation must originate from a dynamic source dim");
+    OpFoldResult origSize = origSizes[groupId];
+
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    AffineExpr s1 = builder.getAffineSymbolExpr(1);
+    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+      auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+      assert(baseAttr && "Expected an attribute stride");
+      int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+      expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+          builder, expandShape.getLoc(),
+          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+          {origSize, origStride});
+    }
+  }
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+    auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+    assert(baseAttr && "Expected an attribute stride");
+    int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+    expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  }
+
+  return expandedStrides;
+}
+
 /// Produce an OpFoldResult representing the product of the values or constants
 /// referenced by `indices`. `staticShape` provides the statically known sizes
 /// for the source memref, while `values` contains the mixed (value/attribute)
@@ -426,6 +535,40 @@ struct FlattenCollapseShape final
   }
 };

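+/// Lowers `memref.expand_shape` into a `memref.reinterpret_cast` whose sizes
+/// and strides are materialized from the source's strided metadata.
+/// Illustrative example, assuming an identity-layout source: expanding
+/// memref<?x8xf32> into memref<?x2x8xf32> with reassociation [[0, 1], [2]]
+/// produces a reinterpret_cast with sizes [%n floordiv 2, 2, 8] and strides
+/// [16, 8, 1], where %n is the source's dynamic size as reported by
+/// memref.extract_strided_metadata.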
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    memref::ExtractStridedMetadataOp metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+    SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+    OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+    SmallVector<OpFoldResult> expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides;
+    unsigned numGroups = op.getReassociationIndices().size();
+    expandedSizes.reserve(op.getResultType().getRank());
+    expandedStrides.reserve(op.getResultType().getRank());
+
+    for (unsigned i = 0; i < numGroups; ++i) {
+      SmallVector<OpFoldResult> groupSizes =
+          getExpandedSizes(op, rewriter, origSizes, i);
+      SmallVector<OpFoldResult> groupStrides =
+          getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+      expandedSizes.append(groupSizes.begin(), groupSizes.end());
+      expandedStrides.append(groupStrides.begin(), groupStrides.end());
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+    return success();
+  }
+};
+
 struct FlattenMemrefsPass
     : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
   using Base::Base;
@@ -501,6 +644,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
                   MemRefRewritePattern<memref::DeallocOp>,
+                  FlattenExpandShape,
                   FlattenCollapseShape,
                   FlattenGetGlobal,
                   FlattenGlobal>(