@@ -103,8 +103,8 @@ static bool checkLayout(Value val) {
103103namespace {
104104static Value getTargetMemref (Operation *op) {
105105 return llvm::TypeSwitch<Operation *, Value>(op)
106- .template Case <memref::LoadOp, memref::StoreOp>(
107- [](auto op) { return op.getMemref (); })
106+ .template Case <memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107+ memref::AllocOp>( [](auto op) { return op.getMemref (); })
108108 .template Case <vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109109 vector::MaskedStoreOp>(
110110 [](auto op) { return op.getBase (); })
@@ -113,54 +113,78 @@ static Value getTargetMemref(Operation *op) {
113113 .Default ([](auto ) { return Value{}; });
114114}
115115
116- static void replaceOp (Operation *op, PatternRewriter &rewriter,
117- Value flatMemref, Value offset) {
116+ template <typename T>
117+ static void replaceOp (T op, PatternRewriter &rewriter, Value flatMemref,
118+ Value offset) {
118119 auto loc = op->getLoc ();
119- llvm::TypeSwitch<Operation *>(op)
120- .Case <memref::LoadOp>([&](auto op) {
120+ llvm::TypeSwitch<Operation *>(op.getOperation ())
121+ .template Case <memref::AllocOp>([&](auto oper) {
122+ // Allocate a new buffer with flatMemref's type, then replace the
123+ // original alloc with a reinterpret_cast of it back to the original type.
124+ auto flatMemrefType = cast<MemRefType>(flatMemref.getType ());
125+ auto loc = oper.getLoc ();
126+ auto newAlloc = rewriter.create <memref::AllocOp>(
127+ loc, flatMemrefType, oper.getAlignmentAttr ());
128+ auto originalType = cast<MemRefType>(oper.getType ());
129+
130+ auto rank = originalType.getRank ();
131+ SmallVector<OpFoldResult, 4 > sizes, strides;
132+ sizes.resize (rank);
133+ strides.resize (rank);
134+ int64_t staticStride = 1 ;
135+ for (int i = rank - 1 ; i >= 0 ; --i) {
136+ sizes[i] = rewriter.getIndexAttr (originalType.getShape ()[i]);
137+ strides[i] = rewriter.getIndexAttr (staticStride);
138+ staticStride *= originalType.getShape ()[i];
139+ }
140+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
141+ op, originalType, newAlloc,
142+ /* offset=*/ rewriter.getIndexAttr (0 ), sizes, strides);
143+ })
144+ .template Case <memref::LoadOp>([&](auto op) {
121145 auto newLoad = rewriter.create <memref::LoadOp>(
122146 loc, op->getResultTypes (), flatMemref, ValueRange{offset});
123147 newLoad->setAttrs (op->getAttrs ());
124148 rewriter.replaceOp (op, newLoad.getResult ());
125149 })
126- .Case <memref::StoreOp>([&](auto op) {
150+ .template Case <memref::StoreOp>([&](auto op) {
127151 auto newStore = rewriter.create <memref::StoreOp>(
128152 loc, op->getOperands ().front (), flatMemref, ValueRange{offset});
129153 newStore->setAttrs (op->getAttrs ());
130154 rewriter.replaceOp (op, newStore);
131155 })
132- .Case <vector::LoadOp>([&](auto op) {
156+ .template Case <vector::LoadOp>([&](auto op) {
133157 auto newLoad = rewriter.create <vector::LoadOp>(
134158 loc, op->getResultTypes (), flatMemref, ValueRange{offset});
135159 newLoad->setAttrs (op->getAttrs ());
136160 rewriter.replaceOp (op, newLoad.getResult ());
137161 })
138- .Case <vector::StoreOp>([&](auto op) {
162+ .template Case <vector::StoreOp>([&](auto op) {
139163 auto newStore = rewriter.create <vector::StoreOp>(
140164 loc, op->getOperands ().front (), flatMemref, ValueRange{offset});
141165 newStore->setAttrs (op->getAttrs ());
142166 rewriter.replaceOp (op, newStore);
143167 })
144- .Case <vector::MaskedLoadOp>([&](auto op) {
168+ .template Case <vector::MaskedLoadOp>([&](auto op) {
145169 auto newMaskedLoad = rewriter.create <vector::MaskedLoadOp>(
146170 loc, op.getType (), flatMemref, ValueRange{offset}, op.getMask (),
147171 op.getPassThru ());
148172 newMaskedLoad->setAttrs (op->getAttrs ());
149173 rewriter.replaceOp (op, newMaskedLoad.getResult ());
150174 })
151- .Case <vector::MaskedStoreOp>([&](auto op) {
175+ .template Case <vector::MaskedStoreOp>([&](auto op) {
152176 auto newMaskedStore = rewriter.create <vector::MaskedStoreOp>(
153177 loc, flatMemref, ValueRange{offset}, op.getMask (),
154178 op.getValueToStore ());
155179 newMaskedStore->setAttrs (op->getAttrs ());
156180 rewriter.replaceOp (op, newMaskedStore);
157181 })
158- .Case <vector::TransferReadOp>([&](auto op) {
182+ .template Case <vector::TransferReadOp>([&](auto op) {
159183 auto newTransferRead = rewriter.create <vector::TransferReadOp>(
160184 loc, op.getType (), flatMemref, ValueRange{offset}, op.getPadding ());
161185 rewriter.replaceOp (op, newTransferRead.getResult ());
162186 })
163- .Case <vector::TransferWriteOp>([&](auto op) {
187+ .template Case <vector::TransferWriteOp>([&](auto op) {
164188 auto newTransferWrite = rewriter.create <vector::TransferWriteOp>(
165189 loc, op.getVector (), flatMemref, ValueRange{offset});
166190 rewriter.replaceOp (op, newTransferWrite);
@@ -170,6 +194,16 @@ static void replaceOp(Operation *op, PatternRewriter &rewriter,
170194 });
171195}
172196
// Returns the index operands to use when flattening the access made by `op`.
// The branch is chosen at compile time from the op type T:
//   - memref::AllocaOp / memref::AllocOp: an empty ValueRange.
//   - any other T: whatever op.getIndices() returns.
// NOTE(review): presumably the alloc ops are special-cased because they have
// no getIndices() accessor — confirm against the memref op definitions.
197+ template <typename T>
198+ static ValueRange getIndices (T op) {
199+ if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200+ std::is_same_v<T, memref::AllocOp>) {
201+ return ValueRange{};
202+ } else {
203+ return op.getIndices ();
204+ }
205+ }
206+
173207template <typename T>
174208struct MemRefRewritePattern : public OpRewritePattern <T> {
175209 using OpRewritePattern<T>::OpRewritePattern;
@@ -179,34 +213,42 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
179213 if (!needFlattening (memref) || !checkLayout (memref))
180214 return failure ();
181215 auto &&[flatMemref, offset] = getFlattenMemrefAndOffset (
182- rewriter, op->getLoc (), memref, op. getIndices ( ));
183- replaceOp (op, rewriter, flatMemref, offset);
216+ rewriter, op->getLoc (), memref, getIndices<T>(op ));
217+ replaceOp<T> (op, rewriter, flatMemref, offset);
184218 return success ();
185219 }
186220};
187221
188- // For any memref op that emits a new memref.
189- template <typename T>
190- struct MemRefSourceRewritePattern : public OpRewritePattern <T> {
191- using OpRewritePattern<T>::OpRewritePattern;
192- LogicalResult matchAndRewrite (T op,
193- PatternRewriter &rewriter) const override {
194- if (!needFlattening (op) || !checkLayout (op))
195- return failure ();
196- MemRefType sourceType = cast<MemRefType>(op.getType ());
197-
198- // Get flattened size, no strides.
199- auto dimSizes = llvm::to_vector (sourceType.getShape ());
200- auto flattenedSize = std::accumulate (
201- dimSizes.begin (), dimSizes.end (), 1 , std::multiplies<int64_t >());
202- auto flatMemrefType = MemRefType::get (
203- /* shape=*/ {flattenedSize}, sourceType.getElementType (),
204- /* layout=*/ nullptr , sourceType.getMemorySpace ());
205- rewriter.replaceOpWithNewOp <T>(
206- op, flatMemrefType);
207- return success ();
208- }
209- };
222+ // // For any memref op that emits a new memref.
223+ // template <typename T>
224+ // struct MemRefSourceRewritePattern : public OpRewritePattern<T> {
225+ // using OpRewritePattern<T>::OpRewritePattern;
226+ // LogicalResult matchAndRewrite(T op,
227+ // PatternRewriter &rewriter) const override {
228+ // if (!needFlattening(op) || !checkLayout(op))
229+ // return failure();
230+ // MemRefType sourceType = cast<MemRefType>(op.getType());
231+
232+ // auto mixedSizes = op.getMixedSizes();
233+
234+ // // Get flattened size, no strides.
235+ // auto flattenedSize = std::accumulate(
236+ // mixedSizes.begin(), mixedSizes.end(), 1,
237+ // [](int64_t a, OpFoldResult b) {
238+ // return a * getConstantIntValue(b).value_or(1);
239+ // });
240+
241+ // auto flatMemrefType = MemRefType::get(
242+ // /*shape=*/{flattenedSize}, sourceType.getElementType(),
243+ // /*layout=*/nullptr, sourceType.getMemorySpace());
244+ // auto newSource = rewriter.create<T>(
245+ // op.getLoc(), flatMemrefType, op.getDynamicSizes());
246+ // auto reinterpretCast = rewriter.create<memref::ReinterpretCastOp>(
247+ // op.getLoc(), sourceType, newSource, op.getOffset(),
248+ // op.getMixedSizes(), op.getStrides());
249+ // return success();
250+ // }
251+ // };
210252
211253struct FlattenMemrefsPass
212254 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -232,8 +274,8 @@ struct FlattenMemrefsPass
232274void memref::populateFlattenMemrefsPatterns (RewritePatternSet &patterns) {
233275 patterns.insert <MemRefRewritePattern<memref::LoadOp>,
234276 MemRefRewritePattern<memref::StoreOp>,
235- MemRefSourceRewritePattern <memref::AllocOp>,
236- MemRefSourceRewritePattern <memref::AllocaOp>,
277+ MemRefRewritePattern <memref::AllocOp>,
278+ MemRefRewritePattern <memref::AllocaOp>,
237279 MemRefRewritePattern<vector::LoadOp>,
238280 MemRefRewritePattern<vector::StoreOp>,
239281 MemRefRewritePattern<vector::TransferReadOp>,
0 commit comments