@@ -166,6 +166,24 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
166
166
StridedLayoutAttr::get (aT.getContext (), resOffset, resStrides));
167
167
}
168
168
169
+ // / Casts the given memref to a compatible memref type. If the source memref has
170
+ // / a different address space than the target type, a `memref.memory_space_cast`
171
+ // / is first inserted, followed by a `memref.cast`.
172
+ static Value castToCompatibleMemRefType (OpBuilder &b, Value memref,
173
+ MemRefType compatibleMemRefType) {
174
+ MemRefType sourceType = memref.getType ().cast <MemRefType>();
175
+ Value res = memref;
176
+ if (sourceType.getMemorySpace () != compatibleMemRefType.getMemorySpace ()) {
177
+ sourceType = MemRefType::get (
178
+ sourceType.getShape (), sourceType.getElementType (),
179
+ sourceType.getLayout (), compatibleMemRefType.getMemorySpace ());
180
+ res = b.create <memref::MemorySpaceCastOp>(memref.getLoc (), sourceType, res);
181
+ }
182
+ if (sourceType == compatibleMemRefType)
183
+ return res;
184
+ return b.create <memref::CastOp>(memref.getLoc (), compatibleMemRefType, res);
185
+ }
186
+
169
187
// / Operates under a scoped context to build the intersection between the
170
188
// / view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
171
189
// TODO: view intersection/union/differences should be a proper std op.
@@ -215,6 +233,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
215
233
// / Produce IR resembling:
216
234
// / ```
217
235
// / %1:3 = scf.if (%inBounds) {
236
+ // / (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
218
237
// / %view = memref.cast %A: memref<A...> to compatibleMemRefType
219
238
// / scf.yield %view, ... : compatibleMemRefType, index, index
220
239
// / } else {
@@ -237,9 +256,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
237
256
return b.create <scf::IfOp>(
238
257
loc, inBoundsCond,
239
258
[&](OpBuilder &b, Location loc) {
240
- Value res = memref;
241
- if (compatibleMemRefType != xferOp.getShapedType ())
242
- res = b.create <memref::CastOp>(loc, compatibleMemRefType, memref);
259
+ Value res = castToCompatibleMemRefType (b, memref, compatibleMemRefType);
243
260
scf::ValueVector viewAndIndices{res};
244
261
viewAndIndices.insert (viewAndIndices.end (), xferOp.getIndices ().begin (),
245
262
xferOp.getIndices ().end ());
@@ -256,7 +273,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
256
273
alloc);
257
274
b.create <memref::CopyOp>(loc, copyArgs.first , copyArgs.second );
258
275
Value casted =
259
- b. create <memref::CastOp>(loc, compatibleMemRefType, alloc );
276
+ castToCompatibleMemRefType (b, alloc, compatibleMemRefType );
260
277
scf::ValueVector viewAndIndices{casted};
261
278
viewAndIndices.insert (viewAndIndices.end (), xferOp.getTransferRank (),
262
279
zero);
@@ -270,6 +287,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
270
287
// / Produce IR resembling:
271
288
// / ```
272
289
// / %1:3 = scf.if (%inBounds) {
290
+ // / (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
273
291
// / memref.cast %A: memref<A...> to compatibleMemRefType
274
292
// / scf.yield %view, ... : compatibleMemRefType, index, index
275
293
// / } else {
@@ -292,9 +310,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
292
310
return b.create <scf::IfOp>(
293
311
loc, inBoundsCond,
294
312
[&](OpBuilder &b, Location loc) {
295
- Value res = memref;
296
- if (compatibleMemRefType != xferOp.getShapedType ())
297
- res = b.create <memref::CastOp>(loc, compatibleMemRefType, memref);
313
+ Value res = castToCompatibleMemRefType (b, memref, compatibleMemRefType);
298
314
scf::ValueVector viewAndIndices{res};
299
315
viewAndIndices.insert (viewAndIndices.end (), xferOp.getIndices ().begin (),
300
316
xferOp.getIndices ().end ());
@@ -309,7 +325,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
309
325
loc, MemRefType::get ({}, vector.getType ()), alloc));
310
326
311
327
Value casted =
312
- b. create <memref::CastOp>(loc, compatibleMemRefType, alloc );
328
+ castToCompatibleMemRefType (b, alloc, compatibleMemRefType );
313
329
scf::ValueVector viewAndIndices{casted};
314
330
viewAndIndices.insert (viewAndIndices.end (), xferOp.getTransferRank (),
315
331
zero);
@@ -343,9 +359,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
343
359
.create <scf::IfOp>(
344
360
loc, inBoundsCond,
345
361
[&](OpBuilder &b, Location loc) {
346
- Value res = memref;
347
- if (compatibleMemRefType != xferOp.getShapedType ())
348
- res = b.create <memref::CastOp>(loc, compatibleMemRefType, memref);
362
+ Value res =
363
+ castToCompatibleMemRefType (b, memref, compatibleMemRefType);
349
364
scf::ValueVector viewAndIndices{res};
350
365
viewAndIndices.insert (viewAndIndices.end (),
351
366
xferOp.getIndices ().begin (),
@@ -354,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
354
369
},
355
370
[&](OpBuilder &b, Location loc) {
356
371
Value casted =
357
- b. create <memref::CastOp>(loc, compatibleMemRefType, alloc );
372
+ castToCompatibleMemRefType (b, alloc, compatibleMemRefType );
358
373
scf::ValueVector viewAndIndices{casted};
359
374
viewAndIndices.insert (viewAndIndices.end (),
360
375
xferOp.getTransferRank (), zero);
0 commit comments