@@ -829,6 +829,98 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
829829
830830// -----
831831
// Slice of width 2 at offset 1 on axis 3 of a three-operand concat (each operand has width 2).
// The slice only touches %arg0 and %arg1, so the expected output concatenates just those two
// operands and keeps the slice start at 1.
832+ // CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap
833+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
834+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
835+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
836+ // CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
837+ func.func @canonicalize_concat_slice_partial_concat_start_overlap (%arg0 : tensor <1 x12 x12 x2 xf32 >, %arg1 : tensor <1 x12 x12 x2 xf32 >, %arg2 : tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x2 xf32 > {
838+ %0 = tosa.concat %arg0 , %arg1 , %arg2 {axis = 3 : i32 } : (tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x6 xf32 >
839+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 2 >, start = array<i64 : 0 , 0 , 0 , 1 >} : (tensor <1 x12 x12 x6 xf32 >) -> tensor <1 x12 x12 x2 xf32 >
840+ return %1 : tensor <1 x12 x12 x2 xf32 >
841+ }
842+
843+ // -----
844+
// Slice of width 2 at offset 3 on axis 3 of a three-operand concat (each operand has width 2).
// The slice only touches %arg1 and %arg2, so the expected output concatenates just those two
// operands and rebases the slice start from 3 to 1 (the width of the dropped %arg0 is subtracted).
845+ // CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap
846+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
847+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
848+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
849+ // CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
850+ func.func @canonicalize_concat_slice_partial_concat_end_overlap (%arg0 : tensor <1 x12 x12 x2 xf32 >, %arg1 : tensor <1 x12 x12 x2 xf32 >, %arg2 : tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x2 xf32 > {
851+ %0 = tosa.concat %arg0 , %arg1 , %arg2 {axis = 3 : i32 } : (tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x6 xf32 >
852+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 2 >, start = array<i64 : 0 , 0 , 0 , 3 >} : (tensor <1 x12 x12 x6 xf32 >) -> tensor <1 x12 x12 x2 xf32 >
853+ return %1 : tensor <1 x12 x12 x2 xf32 >
854+ }
855+
856+ // -----
857+
// Slice of width 4 at offset 1 on axis 3 of a three-operand concat (each operand has width 2).
// The slice touches all three operands, so no operand can be dropped: the expected output keeps
// the concat and the slice exactly as written in the input.
858+ // CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap
859+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> {
860+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
861+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 4>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32>
862+ // CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32>
863+ func.func @canonicalize_concat_slice_partial_concat_all_overlap (%arg0 : tensor <1 x12 x12 x2 xf32 >, %arg1 : tensor <1 x12 x12 x2 xf32 >, %arg2 : tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x4 xf32 > {
864+ %0 = tosa.concat %arg0 , %arg1 , %arg2 {axis = 3 : i32 } : (tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x6 xf32 >
865+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 4 >, start = array<i64 : 0 , 0 , 0 , 1 >} : (tensor <1 x12 x12 x6 xf32 >) -> tensor <1 x12 x12 x4 xf32 >
866+ return %1 : tensor <1 x12 x12 x4 xf32 >
867+ }
868+
869+ // -----
870+
// Same concat and slice as the start-overlap case (offset 1, width 2), but the concat result is
// also returned from the function. The expected output leaves both ops unchanged.
// NOTE(review): presumably the rewrite is skipped because the concat has a second use; confirm
// against the canonicalization pattern.
871+ // CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use
872+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) {
873+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
874+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
875+ // CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>
876+ func.func @canonicalize_concat_slice_partial_concat_multi_use (%arg0 : tensor <1 x12 x12 x2 xf32 >, %arg1 : tensor <1 x12 x12 x2 xf32 >, %arg2 : tensor <1 x12 x12 x2 xf32 >) -> (tensor <1 x12 x12 x6 xf32 >, tensor <1 x12 x12 x2 xf32 >) {
877+ %0 = tosa.concat %arg0 , %arg1 , %arg2 {axis = 3 : i32 } : (tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x6 xf32 >
878+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 2 >, start = array<i64 : 0 , 0 , 0 , 1 >} : (tensor <1 x12 x12 x6 xf32 >) -> tensor <1 x12 x12 x2 xf32 >
879+ return %0 , %1 : tensor <1 x12 x12 x6 xf32 >, tensor <1 x12 x12 x2 xf32 >
880+ }
881+
882+ // -----
883+
// Slice with size 0 on the concat axis (axis 3), starting at offset 0. The expected output
// leaves the three-operand concat and the zero-width slice unchanged.
// NOTE(review): presumably this guards the rewrite against an empty slice range; confirm
// against the canonicalization pattern.
884+ // CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim
885+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> {
886+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
887+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 0>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32>
888+ // CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32>
889+ // CHECK: }
890+ func.func @canonicalize_concat_slice_zero_dim (%arg0 : tensor <1 x12 x12 x2 xf32 >, %arg1 : tensor <1 x12 x12 x2 xf32 >, %arg2 : tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x0 xf32 > {
891+ %0 = tosa.concat %arg0 , %arg1 , %arg2 {axis = 3 : i32 } : (tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >, tensor <1 x12 x12 x2 xf32 >) -> tensor <1 x12 x12 x6 xf32 >
892+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 0 >, start = array<i64 : 0 , 0 , 0 , 0 >} : (tensor <1 x12 x12 x6 xf32 >) -> tensor <1 x12 x12 x0 xf32 >
893+ return %1 : tensor <1 x12 x12 x0 xf32 >
894+ }
895+
896+ // -----
897+
// Tile by 10 on every axis of a 1x12x12x10x10 input, then slice a 1x120x12x10x16 window at
// start 0, 0, 1, 1, 18. The expected output shrinks the tile multiples to 1, 10, 2, 2, 3 and
// rebases the last-axis slice start from 18 to 8 (18 modulo the input extent 10); the other
// starts are already below their input extents and stay as written.
898+ // CHECK-LABEL: func.func @canonicalize_tile_slice
899+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> {
900+ // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 10, 2, 2, 3>} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32>
901+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 120, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 8>} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32>
902+ // CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32>
903+ func.func @canonicalize_tile_slice (%arg0 : tensor <1 x12 x12 x10 x10 xf32 >) -> tensor <1 x120 x12 x10 x16 xf32 > {
904+ %0 = tosa.tile %arg0 {multiples = array<i64 : 10 , 10 , 10 , 10 , 10 >} : (tensor <1 x12 x12 x10 x10 xf32 >) -> tensor <10 x120 x120 x100 x100 xf32 >
905+ %1 = tosa.slice %0 {size = array<i64 : 1 , 120 , 12 , 10 , 16 >, start = array<i64 : 0 , 0 , 1 , 1 , 18 >} : (tensor <10 x120 x120 x100 x100 xf32 >) -> tensor <1 x120 x12 x10 x16 xf32 >
906+ return %1 : tensor <1 x120 x12 x10 x16 xf32 >
907+ }
908+
909+ // -----
910+
// Tile by 10 on every axis whose result is both sliced and returned from the function. The
// expected output keeps the original multiples (all 10) and the original slice start
// (0, 0, 1, 1, 18) unchanged.
// NOTE(review): presumably the rewrite is skipped because the tile has a second use; confirm
// against the canonicalization pattern.
911+ // CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output
912+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) {
913+ // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
914+ // CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32>
915+ // CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>
916+ func.func @canonicalize_tile_slice_multi_output (%arg0 : tensor <1 x12 x12 x10 x10 xf32 >) -> (tensor <10 x120 x120 x100 x100 xf32 >, tensor <1 x12 x12 x10 x16 xf32 >) {
917+ %0 = tosa.tile %arg0 {multiples = array<i64 : 10 , 10 , 10 , 10 , 10 >} : (tensor <1 x12 x12 x10 x10 xf32 >) -> tensor <10 x120 x120 x100 x100 xf32 >
918+ %1 = tosa.slice %0 {size = array<i64 : 1 , 12 , 12 , 10 , 16 >, start = array<i64 : 0 , 0 , 1 , 1 , 18 >} : (tensor <10 x120 x120 x100 x100 xf32 >) -> tensor <1 x12 x12 x10 x16 xf32 >
919+ return %0 , %1 : tensor <10 x120 x120 x100 x100 xf32 >, tensor <1 x12 x12 x10 x16 xf32 >
920+ }
921+
922+ // -----
923+
832924// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
833925func.func @canonicalize_optimize_sqrt_reciprocal (%arg0: tensor <1 x5 x1 x1 xf32 >) -> tensor <1 x5 x1 x1 xf32 > {
834926 // CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
0 commit comments