@@ -51,31 +51,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151 return cast<Value>(in);
5252}
5353
54- // / Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
55- // / span of the memref.
56- static OpFoldResult computeSize (OpBuilder &builder, Location loc,
57- ArrayRef<OpFoldResult> dims,
58- ArrayRef<OpFoldResult> strides) {
59- assert (dims.size () == strides.size () &&
60- " number of dimensions and strides should be equal" );
61- SmallVector<AffineExpr> symbols (2 * dims.size ());
62- bindSymbolsList (builder.getContext (), MutableArrayRef{symbols});
63- SmallVector<AffineExpr> productExpressions;
64- SmallVector<OpFoldResult> values;
65- size_t symbolIndex = 0 ;
66- for (auto &&[dim, stride] : llvm::zip (dims, strides)) {
67- AffineExpr dimExpr = symbols[symbolIndex++];
68- AffineExpr strideExpr = symbols[symbolIndex++];
69- productExpressions.push_back (dimExpr * strideExpr);
70- values.push_back (dim);
71- values.push_back (stride);
72- }
73-
74- AffineMap maxMap = AffineMap::get (0 , symbols.size (), productExpressions,
75- builder.getContext ());
76- return affine::makeComposedFoldedAffineMax (builder, loc, maxMap, values);
77- }
78-
7954// / Returns a collapsed memref and the linearized index to access the element
8055// / at the specified indices.
8156static std::pair<Value, Value> getFlattenMemrefAndOffset (OpBuilder &rewriter,
@@ -108,9 +83,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
10883 loc, source,
10984 /* offset = */ linearizedInfo.linearizedOffset ,
11085 /* shapes = */
111- ArrayRef<OpFoldResult>{computeSize (
112- rewriter, loc, stridedMetadata.getConstifiedMixedSizes (),
113- stridedMetadata.getConstifiedMixedStrides ())},
86+ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize },
11487 /* strides = */
11588 ArrayRef<OpFoldResult>{rewriter.getIndexAttr (1 )}),
11689 getValueFromOpFoldResult (rewriter, loc, linearizedIndices));
@@ -133,16 +106,15 @@ static Value getTargetMemref(Operation *op) {
133106 .template Case <memref::LoadOp, memref::StoreOp, memref::AllocaOp,
134107 memref::AllocOp>([](auto op) { return op.getMemref (); })
135108 .template Case <vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
136- vector::MaskedStoreOp>(
109+ vector::MaskedStoreOp, vector::TransferReadOp,
110+ vector::TransferWriteOp>(
137111 [](auto op) { return op.getBase (); })
138- .template Case <vector::TransferReadOp, vector::TransferWriteOp>(
139- [](auto op) { return op.getSource (); })
140112 .Default ([](auto ) { return Value{}; });
141113}
142114
143115template <typename T>
144- static void castResult (T oper, T newOper, Location loc,
145- PatternRewriter &rewriter) {
116+ static void castAllocResult (T oper, T newOper, Location loc,
117+ PatternRewriter &rewriter) {
146118 memref::ExtractStridedMetadataOp stridedMetadata =
147119 rewriter.create <memref::ExtractStridedMetadataOp>(loc, oper);
148120 rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
@@ -155,19 +127,19 @@ static void castResult(T oper, T newOper, Location loc,
155127template <typename T>
156128static void replaceOp (T op, PatternRewriter &rewriter, Value flatMemref,
157129 Value offset) {
158- auto loc = op->getLoc ();
130+ Location loc = op->getLoc ();
159131 llvm::TypeSwitch<Operation *>(op.getOperation ())
160132 .template Case <memref::AllocOp>([&](auto oper) {
161133 auto newAlloc = rewriter.create <memref::AllocOp>(
162134 loc, cast<MemRefType>(flatMemref.getType ()),
163135 oper.getAlignmentAttr ());
164- castResult (oper, newAlloc, loc, rewriter);
136+ castAllocResult (oper, newAlloc, loc, rewriter);
165137 })
166138 .template Case <memref::AllocaOp>([&](auto oper) {
167139 auto newAlloca = rewriter.create <memref::AllocaOp>(
168140 loc, cast<MemRefType>(flatMemref.getType ()),
169141 oper.getAlignmentAttr ());
170- castResult (oper, newAlloca, loc, rewriter);
142+ castAllocResult (oper, newAlloca, loc, rewriter);
171143 })
172144 .template Case <memref::LoadOp>([&](auto op) {
173145 auto newLoad = rewriter.create <memref::LoadOp>(
@@ -232,11 +204,42 @@ static ValueRange getIndices(T op) {
232204 }
233205}
234206
207+ template <typename T>
208+ static LogicalResult canBeFlattened (T op, PatternRewriter &rewriter) {
209+ return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation ())
210+ .template Case <vector::TransferReadOp, vector::TransferWriteOp>(
211+ [&](auto oper) {
212+ // For vector.transfer_read/write, must make sure:
213+ // 1. all accesses are inbound, and
214+ // 2. has an identity or minor identity permutation map.
215+ auto permutationMap = oper.getPermutationMap ();
216+ if (!permutationMap.isIdentity () &&
217+ !permutationMap.isMinorIdentity ()) {
218+ return rewriter.notifyMatchFailure (
219+ oper, " only identity permutation map is supported" );
220+ }
221+ mlir::ArrayAttr inbounds = oper.getInBounds ();
222+ if (llvm::any_of (inbounds, [](Attribute attr) {
223+ return !cast<BoolAttr>(attr).getValue ();
224+ })) {
225+ return rewriter.notifyMatchFailure (oper,
226+ " only inbounds are supported" );
227+ }
228+ return success ();
229+ })
230+ .Default ([&](auto op) { return success (); });
231+ }
232+
235233template <typename T>
236234struct MemRefRewritePattern : public OpRewritePattern <T> {
237235 using OpRewritePattern<T>::OpRewritePattern;
238236 LogicalResult matchAndRewrite (T op,
239237 PatternRewriter &rewriter) const override {
238+ LogicalResult canFlatten = canBeFlattened (op, rewriter);
239+ if (failed (canFlatten)) {
240+ return canFlatten;
241+ }
242+
240243 Value memref = getTargetMemref (op);
241244 if (!needFlattening (memref) || !checkLayout (memref))
242245 return failure ();
0 commit comments