Skip to content

Commit ca02f36

Browse files
authored
[mlir][vector] Teach TransferOptimization to forward masked stores (llvm#87794)
This only handles one case (which is fairly common in practice*): storing a masked constant splat, then reloading with the same mask and a padding value that matches the splat. (*) For SVE/SME (without loop peeling) this occurs when a `linalg.fill` precedes a `linalg.matmul`.
1 parent 72a8953 commit ca02f36

File tree

2 files changed

+160
-4
lines changed

2 files changed

+160
-4
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,43 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
170170
shapedType.getContext());
171171
}
172172

173+
/// Check if `write` is of a constant splat and the masked `read` is padded with
174+
/// the same splat value -- meaning it could be the same value as the initial
175+
/// constant splat.
176+
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
177+
vector::TransferReadOp read) {
178+
auto readMask = read.getMask();
179+
auto writeMask = write.getMask();
180+
// Check if the masks are consistent. The splat value could be the same if the
181+
// read is masked (and padded with the splat value), and the write is unmasked
182+
// or has the same mask. Note this does not allow the case where the write is
183+
// masked and the read is unmasked, as then the read could be of more elements
184+
// than the write (which may not be the same value).
185+
bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
186+
if (!couldBeSameSplat)
187+
return false;
188+
// Check for constant splat (as the source of the write).
189+
DenseElementsAttr splatAttr;
190+
if (!matchPattern(write.getVector(),
191+
m_Constant<DenseElementsAttr>(&splatAttr)) ||
192+
!splatAttr.isSplat()) {
193+
return false;
194+
}
195+
// The padding of the read and the constant splat value must be the same.
196+
Attribute padAttr;
197+
if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
198+
return false;
199+
return padAttr == splatAttr.getSplatValue<Attribute>();
200+
}
201+
173202
/// Returns true if the `read` is guaranteed to observe exactly the value
/// stored by `defWrite` (i.e. read-after-write forwarding is safe).
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  // Out-of-bounds dims on the write make the stored region unclear.
  if (defWrite.hasOutOfBoundsDim())
    return false;
  // Both transfers must address the same elements in the same layout.
  if (defWrite.getIndices() != read.getIndices() ||
      defWrite.getVectorType() != read.getVectorType() ||
      defWrite.getPermutationMap() != read.getPermutationMap())
    return false;
  // Trivially safe when neither side is masked.
  if (!defWrite.getMask() && !read.getMask())
    return true;
  // Otherwise, only safe for a constant-splat store re-read with a
  // consistent mask and matching padding.
  return isSplatWriteConsistentWithMaskedRead(defWrite, read);
}
180211

181212
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,

mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
222222
// `vector.transfer_write` would not be safe:
223223
// %1 = vector.transfer_read %subview
224224
// vector.transfer_write %1, %alloca
225-
// vector.transfer_write %vec, %collapse_shape
225+
// vector.transfer_write %vec, %collapse_shape
226226
// %2 = vector.transfer_read %alloca
227227
// vector.transfer_write %1, %subview
228228
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,128 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
360360
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
361361
return
362362
}
363+
364+
// Positive case: masked store of a zero splat, reloaded with the same mask
// and zero padding. Both the store and the reload can be eliminated.
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  %pad = arith.constant 0.0 : f32
  vector.transfer_write %splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  %init = vector.transfer_read %buffer[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %result = scf.for %iv = %c0 to %c512 step %c1 iter_args(%acc = %init) -> (vector<[8]x[8]xf32>) {
    %sum = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %sum : vector<[8]x[8]xf32>
  }
  vector.transfer_write %result, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
388+
389+
// Here the read can be eliminated but not the write (due to mismatched masks).
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
// CHECK: vector.transfer_write %[[SPLAT]]
// CHECK: scf.for
// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  %pad = arith.constant 0.0 : f32
  // Unmasked splat store; the masked read (matching padding) still folds.
  vector.transfer_write %splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  %init = vector.transfer_read %buffer[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %result = scf.for %iv = %c0 to %c512 step %c1 iter_args(%acc = %init) -> (vector<[8]x[8]xf32>) {
    %sum = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %sum : vector<[8]x[8]xf32>
  }
  vector.transfer_write %result, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
413+
414+
// Negative test: the padding does not match the constant splat, so we can't
// forward the store.
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  // Padding (1.0) differs from the splat (0.0): masked-off lanes may differ.
  %pad = arith.constant 1.0 : f32
  vector.transfer_write %splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  %init = vector.transfer_read %buffer[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %result = scf.for %iv = %c0 to %c512 step %c1 iter_args(%acc = %init) -> (vector<[8]x[8]xf32>) {
    %sum = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %sum : vector<[8]x[8]xf32>
  }
  vector.transfer_write %result, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
438+
439+
// Negative test: the masks don't match between the read and write, so we can't
// forward the store.
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  %pad = arith.constant 0.0 : f32
  // Store under %mask_a but reload under %mask_b: not provably the same value.
  vector.transfer_write %splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  %init = vector.transfer_read %buffer[%c0, %c0], %pad, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %result = scf.for %iv = %c0 to %c512 step %c1 iter_args(%acc = %init) -> (vector<[8]x[8]xf32>) {
    %sum = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %sum : vector<[8]x[8]xf32>
  }
  vector.transfer_write %result, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
463+
464+
// Negative test: here the write is masked but the read is unmasked. We can't
// forward the store (as the write could be of fewer elements than the read).
// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK: }
// CHECK: vector.transfer_write
// CHECK: return
func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
  %pad = arith.constant 0.0 : f32
  vector.transfer_write %splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  // Unmasked read: may cover elements the masked write never stored.
  %init = vector.transfer_read %buffer[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
  %result = scf.for %iv = %c0 to %c512 step %c1 iter_args(%acc = %init) -> (vector<[8]x[8]xf32>) {
    %sum = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %sum : vector<[8]x[8]xf32>
  }
  vector.transfer_write %result, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}

0 commit comments

Comments
 (0)