@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
//   %1 = vector.transfer_read %subview
//   vector.transfer_write %1, %alloca
- //   vector.transfer_write %vec, %collapse_shape
+ //   vector.transfer_write %vec, %collapse_shape
//   %2 = vector.transfer_read %alloca
//   vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,128 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
  return
}
+
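+ // Positive test: the write and read use the same mask, and the read padding
+ // (0.0) matches the stored splat, so masked-off lanes of the read yield the
+ // splat value anyway. Both the read and the initial write can be eliminated.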
+ // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking
+ // CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+ // CHECK-NOT: vector.transfer_write
+ // CHECK-NOT: vector.transfer_read
+ // CHECK: scf.for
+ // CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+ // CHECK: }
+ // CHECK: vector.transfer_write
+ // CHECK: return
+ func.func @forward_dead_constant_splat_store_with_masking(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+   %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+   %read_padding = arith.constant 0.0 : f32
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c512 = arith.constant 512 : index
+   vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+     scf.yield %1 : vector<[8]x[8]xf32>
+   }
+   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // Here the read can be eliminated but not the write (due to mismatched
+ // masks): the read's padding (0.0) matches the splat on masked-off lanes,
+ // but the final masked write may leave elements stored by the initial
+ // unmasked write visible, so that write is not dead.
+ // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_unmasked_write
+ // CHECK: %[[SPLAT:.*]] = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
+ // CHECK: vector.transfer_write %[[SPLAT]]
+ // CHECK: scf.for
+ // CHECK-SAME: iter_args(%{{.*}} = %[[SPLAT]])
+ // CHECK: }
+ // CHECK: vector.transfer_write
+ // CHECK: return
+ func.func @forward_dead_constant_splat_store_with_masking_unmasked_write(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+   %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+   %read_padding = arith.constant 0.0 : f32
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c512 = arith.constant 512 : index
+   vector.transfer_write %zero_splat, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+     scf.yield %1 : vector<[8]x[8]xf32>
+   }
+   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // Negative test: the padding does not match the constant splat, so we can't
+ // forward the store (masked-off lanes of the read take the padding value,
+ // 1.0, rather than the stored 0.0 splat).
+ // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_0
+ // CHECK: vector.transfer_write
+ // CHECK: vector.transfer_read
+ // CHECK: scf.for
+ // CHECK: }
+ // CHECK: vector.transfer_write
+ // CHECK: return
+ func.func @forward_dead_constant_splat_store_with_masking_negative_0(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+   %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+   %read_padding = arith.constant 1.0 : f32
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c512 = arith.constant 512 : index
+   vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+     scf.yield %1 : vector<[8]x[8]xf32>
+   }
+   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // Negative test: the masks don't match between the read and write, so we
+ // can't forward the store (lanes masked off in the write but active in the
+ // read would not contain the splat).
+ // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_1
+ // CHECK: vector.transfer_write
+ // CHECK: vector.transfer_read
+ // CHECK: scf.for
+ // CHECK: }
+ // CHECK: vector.transfer_write
+ // CHECK: return
+ func.func @forward_dead_constant_splat_store_with_masking_negative_1(%buffer : memref<?x?xf32>, %mask_a: vector<[8]x[8]xi1>, %mask_b: vector<[8]x[8]xi1>) {
+   %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+   %read_padding = arith.constant 0.0 : f32
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c512 = arith.constant 512 : index
+   vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding, %mask_b {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+     scf.yield %1 : vector<[8]x[8]xf32>
+   }
+   vector.transfer_write %x, %buffer[%c0, %c0], %mask_a {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
+
+ // Negative test: here the write is masked but the read is unmasked, so we
+ // can't forward the store (the masked write may write fewer elements than
+ // the read reads).
+ // CHECK-LABEL: func @forward_dead_constant_splat_store_with_masking_negative_3
+ // CHECK: vector.transfer_write
+ // CHECK: vector.transfer_read
+ // CHECK: scf.for
+ // CHECK: }
+ // CHECK: vector.transfer_write
+ // CHECK: return
+ func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : memref<?x?xf32>, %mask: vector<[8]x[8]xi1>) {
+   %zero_splat = arith.constant dense<0.0> : vector<[8]x[8]xf32>
+   %read_padding = arith.constant 0.0 : f32
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c512 = arith.constant 512 : index
+   vector.transfer_write %zero_splat, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   %0 = vector.transfer_read %buffer[%c0, %c0], %read_padding {in_bounds = [true, true]} : memref<?x?xf32>, vector<[8]x[8]xf32>
+   %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %0) -> (vector<[8]x[8]xf32>) {
+     %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
+     scf.yield %1 : vector<[8]x[8]xf32>
+   }
+   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+   return
+ }
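
For reference, the IR expected after the optimization for the first test above, as implied by its CHECK lines (a sketch reconstructed from those directives, not part of the commit):

  %splat = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
  %x = scf.for %arg2 = %c0 to %c512 step %c1 iter_args(%acc = %splat) -> (vector<[8]x[8]xf32>) {
    %1 = arith.addf %acc, %acc : vector<[8]x[8]xf32>
    scf.yield %1 : vector<[8]x[8]xf32>
  }
  vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>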