@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
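+// Note: the fold is expected to preserve the nontemporal unit attribute when
+// rewriting the load, which is what the updated CHECK line above verifies.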
 
 // -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   %c0 = arith.constant 0 : index
   %c1f32 = arith.constant 1.0 : f32
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
 // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
@@ -819,29 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
 
-// CHECK: func @fold_vector_load
+// CHECK: func @fold_vector_load_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
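+// Explanatory note: the subview above is rank-reducing with unit sizes and
+// strides, so the fold simply moves its offsets (%arg1, %arg2) into the
+// vector.load indices, as the CHECK line shows.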
 
 // -----
 
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
 
-// CHECK: func @fold_vector_maskedload
+// CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
851851
852852// -----
853853
854- func.func @fold_vector_store (
854+ func.func @fold_vector_store_subview (
855855 %arg0 : memref <12 x32 xf32 >, %arg1 : index , %arg2 : index , %arg3: vector <2 x32 xf32 >) -> () {
856856 %0 = memref.subview %arg0 [%arg1 , %arg2 ][1 , 1 ][1 , 1 ] : memref <12 x32 xf32 > to memref <f32 , strided <[], offset : ?>>
857857 vector.store %arg3 , %0 [] : memref <f32 , strided <[], offset : ?>>, vector <2 x32 xf32 >
858858 return
859859}
860860
861- // CHECK: func @fold_vector_store
861+ // CHECK: func @fold_vector_store_subview
862862// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
863863// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
864864// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,18 +868,166 @@ func.func @fold_vector_store(
868868
869869// -----
870870
871- func.func @fold_vector_maskedstore (
871+ func.func @fold_vector_maskedstore_subview (
872872 %arg0 : memref <12 x32 xf32 >, %arg1 : index , %arg2 : index , %arg3: vector <32 xi1 >, %arg4: vector <32 xf32 >) -> () {
873873 %0 = memref.subview %arg0 [%arg1 , %arg2 ][1 , 1 ][1 , 1 ] : memref <12 x32 xf32 > to memref <f32 , strided <[], offset : ?>>
874874 vector.maskedstore %0 [], %arg3 , %arg4 : memref <f32 , strided <[], offset : ?>>, vector <32 xi1 >, vector <32 xf32 >
875875 return
876876}
877877
878- // CHECK: func @fold_vector_maskedstore
878+ // CHECK: func @fold_vector_maskedstore_subview
879879// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
880880// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
881881// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
882882// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
883883// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
884884// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
885885// CHECK: return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
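+// Worked index arithmetic (explanatory comment only): element [%arg1, %c0] of
+// the expanded memref<4x8xf32> is element %arg1 * 8 + 0 of the flat
+// memref<32xf32>, i.e. the affine_map<()[s0] -> (s0 * 8)> applied above.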
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
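+// As the CHECK line shows, only the base memref and the indices are rewritten;
+// the mask (%arg3) and pass-through (%arg4) operands pass through unchanged.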
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
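+// Worked index arithmetic (explanatory comment only): collapse_shape folding
+// goes the other way and delinearizes: flat index %arg1 on memref<32xf32>
+// becomes [%arg1 floordiv 8, %arg1 mod 8] on memref<4x8xf32>, matching the
+// two affine.apply ops above.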
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]