@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
513513 %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
514514 return %1 : f32
515515}
516+
517+ //-----------------------------------------------------------------------------
518+ // [Pattern: ExtractOpFromLoad]
519+ //-----------------------------------------------------------------------------
520+
521+ // CHECK-LABEL: @extract_load_scalar
522+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
523+ func.func @extract_load_scalar (%arg0: memref <?xf32 >, %arg1: index ) -> f32 {
524+ // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
525+ // CHECK: return %[[RES]] : f32
526+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <4 xf32 >
527+ %1 = vector.extract %0 [0 ] : f32 from vector <4 xf32 >
528+ return %1 : f32
529+ }
530+
531+ // CHECK-LABEL: @extract_load_scalar_non_zero_off
532+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
533+ func.func @extract_load_scalar_non_zero_off (%arg0: memref <?xf32 >, %arg1: index ) -> f32 {
534+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
535+ // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
536+ // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
537+ // CHECK: return %[[RES]] : f32
538+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <4 xf32 >
539+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
540+ return %1 : f32
541+ }
542+
543+ // CHECK-LABEL: @extract_load_scalar_dyn_off
544+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
545+ func.func @extract_load_scalar_dyn_off (%arg0: memref <?xf32 >, %arg1: index , %arg2: index ) -> f32 {
546+ // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
547+ // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
548+ // CHECK: return %[[RES]] : f32
549+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <4 xf32 >
550+ %1 = vector.extract %0 [%arg2 ] : f32 from vector <4 xf32 >
551+ return %1 : f32
552+ }
553+
554+ // CHECK-LABEL: @extract_load_vec
555+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
556+ func.func @extract_load_vec (%arg0: memref <?x?xf32 >, %arg1: index , %arg2: index ) -> vector <4 xf32 > {
557+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
558+ // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
559+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
560+ // CHECK: return %[[RES]] : vector<4xf32>
561+ %0 = vector.load %arg0 [%arg1 , %arg2 ] : memref <?x?xf32 >, vector <2 x4 xf32 >
562+ %1 = vector.extract %0 [1 ] : vector <4 xf32 > from vector <2 x4 xf32 >
563+ return %1 : vector <4 xf32 >
564+ }
565+
566+ // CHECK-LABEL: @extract_load_scalar_high_rank
567+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
568+ func.func @extract_load_scalar_high_rank (%arg0: memref <?x?xf32 >, %arg1: index , %arg2: index ) -> f32 {
569+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
570+ // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
571+ // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
572+ // CHECK: return %[[RES]] : f32
573+ %0 = vector.load %arg0 [%arg1 , %arg2 ] : memref <?x?xf32 >, vector <4 xf32 >
574+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
575+ return %1 : f32
576+ }
577+
578+ // CHECK-LABEL: @extract_load_vec_high_rank
579+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
580+ func.func @extract_load_vec_high_rank (%arg0: memref <?x?x?xf32 >, %arg1: index , %arg2: index , %arg3: index ) -> vector <4 xf32 > {
581+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
582+ // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
583+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
584+ // CHECK: return %[[RES]] : vector<4xf32>
585+ %0 = vector.load %arg0 [%arg1 , %arg2 , %arg3 ] : memref <?x?x?xf32 >, vector <2 x4 xf32 >
586+ %1 = vector.extract %0 [1 ] : vector <4 xf32 > from vector <2 x4 xf32 >
587+ return %1 : vector <4 xf32 >
588+ }
589+
590+ // CHECK-LABEL: @negative_load_scalar_from_vec_memref
591+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
592+ func.func @negative_load_scalar_from_vec_memref (%arg0: memref <?xvector <4 xf32 >>, %arg1: index ) -> f32 {
593+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
594+ // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
595+ // CHECK: return %[[EXT]] : f32
596+ %0 = vector.load %arg0 [%arg1 ] : memref <?xvector <4 xf32 >>, vector <4 xf32 >
597+ %1 = vector.extract %0 [0 ] : f32 from vector <4 xf32 >
598+ return %1 : f32
599+ }
600+
601+ // CHECK-LABEL: @negative_extract_load_no_single_use
602+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
603+ func.func @negative_extract_load_no_single_use (%arg0: memref <?xf32 >, %arg1: index ) -> (f32 , vector <4 xf32 >) {
604+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
605+ // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
606+ // CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
607+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <4 xf32 >
608+ %1 = vector.extract %0 [0 ] : f32 from vector <4 xf32 >
609+ return %1 , %0 : f32 , vector <4 xf32 >
610+ }
611+
612+ // CHECK-LABEL: @negative_load_scalable
613+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
614+ func.func @negative_load_scalable (%arg0: memref <?xf32 >, %arg1: index ) -> f32 {
615+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
616+ // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
617+ // CHECK: return %[[EXT]] : f32
618+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <[1 ]xf32 >
619+ %1 = vector.extract %0 [0 ] : f32 from vector <[1 ]xf32 >
620+ return %1 : f32
621+ }
622+
623+ // CHECK-LABEL: @negative_extract_load_unsupported_ranks
624+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
625+ func.func @negative_extract_load_unsupported_ranks (%arg0: memref <?xf32 >, %arg1: index ) -> vector <4 xf32 > {
626+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
627+ // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
628+ // CHECK: return %[[EXT]] : vector<4xf32>
629+ %0 = vector.load %arg0 [%arg1 ] : memref <?xf32 >, vector <2 x4 xf32 >
630+ %1 = vector.extract %0 [1 ] : vector <4 xf32 > from vector <2 x4 xf32 >
631+ return %1 : vector <4 xf32 >
632+ }
633+
634+ //-----------------------------------------------------------------------------
635+ // [Pattern: StoreFromSplat]
636+ //-----------------------------------------------------------------------------
637+
638+ // CHECK-LABEL: @store_splat
639+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
640+ func.func @store_splat (%arg0: memref <?xf32 >, %arg1: index , %arg2: f32 ) {
641+ // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
642+ %0 = vector.splat %arg2 : vector <1 xf32 >
643+ vector.store %0 , %arg0 [%arg1 ] : memref <?xf32 >, vector <1 xf32 >
644+ return
645+ }
646+
647+ // CHECK-LABEL: @store_broadcast
648+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
649+ func.func @store_broadcast (%arg0: memref <?xf32 >, %arg1: index , %arg2: f32 ) {
650+ // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
651+ %0 = vector.broadcast %arg2 : f32 to vector <1 xf32 >
652+ vector.store %0 , %arg0 [%arg1 ] : memref <?xf32 >, vector <1 xf32 >
653+ return
654+ }
655+
656+ // CHECK-LABEL: @store_broadcast_1d_2d
657+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
658+ func.func @store_broadcast_1d_2d (%arg0: memref <?x?xf32 >, %arg1: index , %arg2: index , %arg3: vector <1 xf32 >) {
659+ // CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
660+ %0 = vector.broadcast %arg3 : vector <1 xf32 > to vector <1 x1 xf32 >
661+ vector.store %0 , %arg0 [%arg1 , %arg2 ] : memref <?x?xf32 >, vector <1 x1 xf32 >
662+ return
663+ }
664+
665+ // CHECK-LABEL: @negative_store_scalable
666+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
667+ func.func @negative_store_scalable (%arg0: memref <?xf32 >, %arg1: index , %arg2: f32 ) {
668+ // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
669+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
670+ %0 = vector.splat %arg2 : vector <[1 ]xf32 >
671+ vector.store %0 , %arg0 [%arg1 ] : memref <?xf32 >, vector <[1 ]xf32 >
672+ return
673+ }
674+
675+ // CHECK-LABEL: @negative_store_vec_memref
676+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
677+ func.func @negative_store_vec_memref (%arg0: memref <?xvector <1 xf32 >>, %arg1: index , %arg2: f32 ) {
678+ // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
679+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
680+ %0 = vector.splat %arg2 : vector <1 xf32 >
681+ vector.store %0 , %arg0 [%arg1 ] : memref <?xvector <1 xf32 >>, vector <1 xf32 >
682+ return
683+ }
684+
685+ // CHECK-LABEL: @negative_store_non_1
686+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
687+ func.func @negative_store_non_1 (%arg0: memref <?xf32 >, %arg1: index , %arg2: f32 ) {
688+ // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
689+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
690+ %0 = vector.splat %arg2 : vector <4 xf32 >
691+ vector.store %0 , %arg0 [%arg1 ] : memref <?xf32 >, vector <4 xf32 >
692+ return
693+ }
694+
695+ // CHECK-LABEL: @negative_store_no_single_use
696+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
697+ func.func @negative_store_no_single_use (%arg0: memref <?xf32 >, %arg1: index , %arg2: f32 ) -> vector <1 xf32 > {
698+ // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
699+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
700+ // CHECK: return %[[RES:.*]] : vector<1xf32>
701+ %0 = vector.splat %arg2 : vector <1 xf32 >
702+ vector.store %0 , %arg0 [%arg1 ] : memref <?xf32 >, vector <1 xf32 >
703+ return %0 : vector <1 xf32 >
704+ }
0 commit comments