55// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
66// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
77// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
8+ // CHECK: vector.mask %[[ALL_TRUE_MASK]] {
9+ // CHECK-SAME: vector.outerproduct
810// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
9- func.func @eliminate_redundant_masks_through_insert_and_extracts (%tensor: tensor <1 x1000 xf32 >) {
10- %c0 = arith.constant 0 : index
11+ #map = affine_map <()[s0 ] -> (-(1080 mod s0 ) + 1080 )>
12+
13+ func.func @eliminate_redundant_masks_through_insert_and_extracts (%tensor: tensor <1 x1000 xf32 >, %rhs : f32 ) {
1114 %c4 = arith.constant 4 : index
12- %c1000 = arith.constant 1000 : index
13- %c0_f32 = arith.constant 0.0 : f32
1415 %vscale = vector.vscale
1516 %c4_vscale = arith.muli %vscale , %c4 : index
17+ %ub = affine.apply #map ()[%c4_vscale ]
18+
19+ %c0 = arith.constant 0 : index
20+ %c1000 = arith.constant 1000 : index
21+ %c0_f32 = arith.constant 0.0 : f32
1622 %extracted_slice_0 = tensor.extract_slice %tensor [0 , 0 ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x1000 xf32 > to tensor <1 x?xf32 >
17- %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args (%arg = %extracted_slice_0 ) -> tensor <1 x?xf32 > {
23+ %output_tensor = scf.for %i = %c0 to %ub step %c4_vscale iter_args (%arg = %extracted_slice_0 ) -> tensor <1 x?xf32 > {
1824 // 1. Extract a slice.
1925 %extracted_slice_1 = tensor.extract_slice %arg [0 , %i ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x?xf32 > to tensor <?xf32 >
2026
@@ -23,8 +29,8 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
2329 %mask = vector.create_mask %dim_1 : vector <[4 ]xi1 >
2430
2531 // 3. Read the slice and do some computation.
26- %vec = vector.transfer_read %extracted_slice_1 [%c0 ], %c0_f32 , %mask {in_bounds = [true ]} : tensor <?xf32 >, vector <[4 ]xf32 >
27- %new_vec = " test.some_computation " ( %vec ) : ( vector <[4 ]xf32 >) -> ( vector <[4 ]xf32 >)
32+ %lhs = vector.transfer_read %extracted_slice_1 [%c0 ], %c0_f32 , %mask {in_bounds = [true ]} : tensor <?xf32 >, vector <[4 ]xf32 >
33+ %new_vec = vector.mask %mask { vector.outerproduct %lhs , %rhs { kind = #vector.kind < add >} : vector <[4 ]xf32 >, f32 } : vector <[ 4 ]x i1 > -> vector <[4 ]xf32 >
2834
2935 // 4. Write the new value.
3036 %write = vector.transfer_write %new_vec , %extracted_slice_1 [%c0 ], %mask {in_bounds = [true ]} : vector <[4 ]xf32 >, tensor <?xf32 >
0 commit comments