@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //   %1 = vector.transfer_read %subview
 //   vector.transfer_write %1, %alloca
-//   vector.transfer_write %vec, %collapse_shape
+//   vector.transfer_write %vec, %collapse_shape
 //   %2 = vector.transfer_read %alloca
 //   vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,128 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Here the read can be eliminated but not the write (due to mismatched masks).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+// CHECK: vector.transfer_write %[[SPLAT]]
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the padding does not match the constant splat, so we can't
+// forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 1.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: the masks don't match between the read and the write, so we
+// can't forward the store.
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// Negative test: here the write is masked but the read is unmasked. We can't
+// forward the store (as the write could cover fewer elements than the read).
+// CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+  %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+  %read_padding = arith.constant 0.0 : f32
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+    scf.yield %1 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}