Skip to content

Commit 5b6b2ca

Browse files
committed
[mlir][vector] Handle memory space conflicts in VectorTransferSplit patterns
Currently, the transfer splitting patterns generate an invalid cast when the source memref of a transfer op has a non-default memory space. Fix this by first introducing a `memref.memory_space_cast` in such cases, followed by the usual `memref.cast`. Differential Revision: https://reviews.llvm.org/D154515
1 parent 0158d86 commit 5b6b2ca

File tree

2 files changed

+92
-12
lines changed

2 files changed

+92
-12
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,24 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
166166
StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
167167
}
168168

169+
/// Casts the given memref to a compatible memref type. If the source memref has
170+
/// a different address space than the target type, a `memref.memory_space_cast`
171+
/// is first inserted, followed by a `memref.cast`.
172+
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
173+
MemRefType compatibleMemRefType) {
174+
MemRefType sourceType = memref.getType().cast<MemRefType>();
175+
Value res = memref;
176+
if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
177+
sourceType = MemRefType::get(
178+
sourceType.getShape(), sourceType.getElementType(),
179+
sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
180+
res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
181+
}
182+
if (sourceType == compatibleMemRefType)
183+
return res;
184+
return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
185+
}
186+
169187
/// Operates under a scoped context to build the intersection between the
170188
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
171189
// TODO: view intersection/union/differences should be a proper std op.
@@ -215,6 +233,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
215233
/// Produce IR resembling:
216234
/// ```
217235
/// %1:3 = scf.if (%inBounds) {
236+
/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
218237
/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
219238
/// scf.yield %view, ... : compatibleMemRefType, index, index
220239
/// } else {
@@ -237,9 +256,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
237256
return b.create<scf::IfOp>(
238257
loc, inBoundsCond,
239258
[&](OpBuilder &b, Location loc) {
240-
Value res = memref;
241-
if (compatibleMemRefType != xferOp.getShapedType())
242-
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
259+
Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
243260
scf::ValueVector viewAndIndices{res};
244261
viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
245262
xferOp.getIndices().end());
@@ -256,7 +273,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
256273
alloc);
257274
b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
258275
Value casted =
259-
b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
276+
castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
260277
scf::ValueVector viewAndIndices{casted};
261278
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
262279
zero);
@@ -270,6 +287,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
270287
/// Produce IR resembling:
271288
/// ```
272289
/// %1:3 = scf.if (%inBounds) {
290+
/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
273291
/// memref.cast %A: memref<A...> to compatibleMemRefType
274292
/// scf.yield %view, ... : compatibleMemRefType, index, index
275293
/// } else {
@@ -292,9 +310,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
292310
return b.create<scf::IfOp>(
293311
loc, inBoundsCond,
294312
[&](OpBuilder &b, Location loc) {
295-
Value res = memref;
296-
if (compatibleMemRefType != xferOp.getShapedType())
297-
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
313+
Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
298314
scf::ValueVector viewAndIndices{res};
299315
viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
300316
xferOp.getIndices().end());
@@ -309,7 +325,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
309325
loc, MemRefType::get({}, vector.getType()), alloc));
310326

311327
Value casted =
312-
b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
328+
castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
313329
scf::ValueVector viewAndIndices{casted};
314330
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
315331
zero);
@@ -343,9 +359,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
343359
.create<scf::IfOp>(
344360
loc, inBoundsCond,
345361
[&](OpBuilder &b, Location loc) {
346-
Value res = memref;
347-
if (compatibleMemRefType != xferOp.getShapedType())
348-
res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
362+
Value res =
363+
castToCompatibleMemRefType(b, memref, compatibleMemRefType);
349364
scf::ValueVector viewAndIndices{res};
350365
viewAndIndices.insert(viewAndIndices.end(),
351366
xferOp.getIndices().begin(),
@@ -354,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
354369
},
355370
[&](OpBuilder &b, Location loc) {
356371
Value casted =
357-
b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
372+
castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
358373
scf::ValueVector viewAndIndices{casted};
359374
viewAndIndices.insert(viewAndIndices.end(),
360375
xferOp.getTransferRank(), zero);

mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,37 @@ func.func @split_vector_transfer_read_strided_2d(
101101
return %1 : vector<4x8xf32>
102102
}
103103

104+
func.func @split_vector_transfer_read_mem_space(%A: memref<?x8xf32, 3>, %i: index, %j: index) -> vector<4x8xf32> {
105+
%c0 = arith.constant 0 : index
106+
%f0 = arith.constant 0.0 : f32
107+
108+
// CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
109+
// inBounds with a different memory space
110+
// CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
111+
// CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
112+
// CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
113+
// CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
114+
// CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
115+
// CHECK: } else {
116+
// slow path, fill tmp alloc and yield a memref_casted version of it
117+
// CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
118+
// CHECK-SAME: memref<?x8xf32, 3>, vector<4x8xf32>
119+
// CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
120+
// CHECK-SAME: memref<4x8xf32> to memref<vector<4x8xf32>>
121+
// CHECK: store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
122+
// CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
123+
// CHECK-SAME: memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
124+
// CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
125+
// CHECK-SAME: memref<?x8xf32, strided<[8, 1]>>, index, index
126+
// CHECK: }
127+
// CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
128+
// CHECK-SAME: {in_bounds = [true, true]} : memref<?x8xf32, strided<[8, 1]>>, vector<4x8xf32>
129+
130+
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32, 3>, vector<4x8xf32>
131+
132+
return %1: vector<4x8xf32>
133+
}
134+
104135
transform.sequence failures(propagate) {
105136
^bb1(%func_op: !transform.op<"func.func">):
106137
transform.apply_patterns to %func_op {
@@ -228,6 +259,40 @@ transform.sequence failures(propagate) {
228259
} : !transform.op<"func.func">
229260
}
230261

262+
// -----
263+
264+
func.func @split_vector_transfer_write_mem_space(%V: vector<4x8xf32>, %A: memref<?x8xf32, 3>, %i: index, %j: index) {
265+
vector.transfer_write %V, %A[%i, %j] :
266+
vector<4x8xf32>, memref<?x8xf32, 3>
267+
return
268+
}
269+
270+
// CHECK: func @split_vector_transfer_write_mem_space(
271+
// CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
272+
// CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
273+
// CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
274+
// CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
275+
// CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
276+
// CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
277+
// CHECK: } else {
278+
// CHECK: %[[VAL_15:.*]] = memref.cast %[[TEMP]]
279+
// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
280+
// CHECK: scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
281+
// CHECK-SAME: : memref<?x8xf32, strided<[8, 1]>>, index, index
282+
// CHECK: }
283+
// CHECK: vector.transfer_write %[[VEC]],
284+
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
285+
// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[8, 1]>>
286+
287+
288+
transform.sequence failures(propagate) {
289+
^bb1(%func_op: !transform.op<"func.func">):
290+
transform.apply_patterns to %func_op {
291+
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
292+
} : !transform.op<"func.func">
293+
}
294+
295+
231296
// -----
232297

233298
func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()

0 commit comments

Comments
 (0)