@@ -868,14 +868,16 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
868868// -----
869869
870870// CHECK-LABEL: fold_vector_transfer_masks
871- func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >) {
871+ func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >, vector < 4 x[ 4 ]x f32 > ) {
872872 // CHECK: %[[C0:.+]] = arith.constant 0 : index
873873 %c0 = arith.constant 0 : index
874874 // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
875875 %f0 = arith.constant 0.0 : f32
876876
877877 %mask = vector.constant_mask [8 , 4 ] : vector <8 x4 xi1 >
878878
879+ %arith_all_true_mask = arith.constant dense <true > : vector <4 x[4 ]xi1 >
880+
879881 // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
880882 %1 = vector.transfer_read %A [%c0 , %c0 ], %f0 , %mask
881883 {permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>} : memref <?x?xf32 >, vector <4 x8 xf32 >
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
884886 vector.transfer_write %1 , %A [%c0 , %c0 ], %mask
885887 {permutation_map = affine_map <(d0 , d1 ) -> (d1 , d0 )>} : vector <4 x8 xf32 >, memref <?x?xf32 >
886888
889+ // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
890+ %2 = vector.transfer_read %A [%c0 , %c0 ], %f0 , %arith_all_true_mask : memref <?x?xf32 >, vector <4 x[4 ]xf32 >
891+
892+ // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
893+ vector.transfer_write %2 , %A [%c0 , %c0 ], %arith_all_true_mask : vector <4 x[4 ]xf32 >, memref <?x?xf32 >
894+
887895 // CHECK: return
888- return %1 : vector <4 x8 xf32 >
896+ return %1 , %2 : vector <4 x8 xf32 >, vector < 4 x[ 4 ]x f32 >
889897}
890898
891899// -----
0 commit comments