@@ -693,6 +693,81 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
693693
694694// -----
695695
696+ // CHECK-LABEL: @slice_fuse
697+ func.func @slice_fuse (%arg0: tensor <3 x4 xf32 >) -> tensor <1 x2 xf32 > {
698+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> {
699+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 2>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x2xf32>
700+ // CHECK: return [[VAR_0_]] : tensor<1x2xf32>
701+ %0 = tosa.slice %arg0 { size = array<i64 : 2 , 3 >, start = array<i64 : 0 , 0 >}: (tensor <3 x4 xf32 >) -> tensor <2 x3 xf32 >
702+ %1 = tosa.slice %0 { size = array<i64 : 1 , 2 >, start = array<i64 : 0 , 0 >}: (tensor <2 x3 xf32 >) -> tensor <1 x2 xf32 >
703+ return %1 : tensor <1 x2 xf32 >
704+ }
705+
706+ // -----
707+
708+ // CHECK-LABEL: @slice_fuse_different_step
709+ func.func @slice_fuse_different_step (%arg0: tensor <3 x4 xf32 >) -> tensor <1 x1 xf32 > {
710+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
711+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
712+ // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
713+ %0 = tosa.slice %arg0 { size = array<i64 : 1 , 3 >, start = array<i64 : 0 , 0 >}: (tensor <3 x4 xf32 >) -> tensor <1 x3 xf32 >
714+ %1 = tosa.slice %0 { size = array<i64 : 1 , 1 >, start = array<i64 : 0 , 0 >}: (tensor <1 x3 xf32 >) -> tensor <1 x1 xf32 >
715+ return %1 : tensor <1 x1 xf32 >
716+ }
717+
718+ // -----
719+
720+ // CHECK-LABEL: @slice_fuse_different_start
721+ func.func @slice_fuse_different_start (%arg0: tensor <3 x4 xf32 >) -> tensor <1 x1 xf32 > {
722+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
723+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 1, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
724+ // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
725+ %0 = tosa.slice %arg0 { size = array<i64 : 1 , 3 >, start = array<i64 : 1 , 0 >}: (tensor <3 x4 xf32 >) -> tensor <1 x3 xf32 >
726+ %1 = tosa.slice %0 { size = array<i64 : 1 , 1 >, start = array<i64 : 0 , 0 >}: (tensor <1 x3 xf32 >) -> tensor <1 x1 xf32 >
727+ return %1 : tensor <1 x1 xf32 >
728+ }
729+
730+ // -----
731+
732+ // CHECK-LABEL: @slice_fuse_different_start_2
733+ func.func @slice_fuse_different_start_2 (%arg0: tensor <10 x10 xf32 >) -> tensor <1 x1 xf32 > {
734+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
735+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
736+ // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
737+ %0 = tosa.slice %arg0 { size = array<i64 : 5 , 5 >, start = array<i64 : 4 , 0 >}: (tensor <10 x10 xf32 >) -> tensor <5 x5 xf32 >
738+ %1 = tosa.slice %0 { size = array<i64 : 3 , 3 >, start = array<i64 : 0 , 0 >}: (tensor <5 x5 xf32 >) -> tensor <3 x3 xf32 >
739+ %2 = tosa.slice %1 { size = array<i64 : 1 , 1 >, start = array<i64 : 0 , 1 >}: (tensor <3 x3 xf32 >) -> tensor <1 x1 xf32 >
740+ return %2 : tensor <1 x1 xf32 >
741+ }
742+
743+ // -----
744+
745+ // CHECK-LABEL: @slice_fuse_different_start_3
746+ func.func @slice_fuse_different_start_3 (%arg0: tensor <10 x10 xf32 >) -> tensor <1 x1 xf32 > {
747+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
748+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 2>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
749+ // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
750+ %0 = tosa.slice %arg0 { size = array<i64 : 5 , 5 >, start = array<i64 : 4 , 1 >}: (tensor <10 x10 xf32 >) -> tensor <5 x5 xf32 >
751+ %1 = tosa.slice %0 { size = array<i64 : 3 , 3 >, start = array<i64 : 0 , 0 >}: (tensor <5 x5 xf32 >) -> tensor <3 x3 xf32 >
752+ %2 = tosa.slice %1 { size = array<i64 : 1 , 1 >, start = array<i64 : 0 , 1 >}: (tensor <3 x3 xf32 >) -> tensor <1 x1 xf32 >
753+ return %2 : tensor <1 x1 xf32 >
754+ }
755+
756+ // -----
757+
758+ // CHECK-LABEL: func.func @slice_fuse_different_start_dynamic
759+ func.func @slice_fuse_different_start_dynamic (%arg0: tensor <*xf32 >) -> tensor <*xf32 > {
760+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
761+ // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<*xf32>) -> tensor<*xf32>
762+ // CHECK: return [[VAR_0_]] : tensor<*xf32>
763+ %0 = tosa.slice %arg0 { size = array<i64 : 5 , 5 >, start = array<i64 : 4 , 0 >}: (tensor <*xf32 >) -> tensor <*xf32 >
764+ %1 = tosa.slice %0 { size = array<i64 : 3 , 3 >, start = array<i64 : 0 , 0 >}: (tensor <*xf32 >) -> tensor <*xf32 >
765+ %2 = tosa.slice %1 { size = array<i64 : 1 , 1 >, start = array<i64 : 0 , 1 >}: (tensor <*xf32 >) -> tensor <*xf32 >
766+ return %2 : tensor <*xf32 >
767+ }
768+
769+ // -----
770+
696771// CHECK-LABEL: @tile_fold
697772func.func @tile_fold (%arg0: tensor <3 x4 xf32 >) -> tensor <3 x4 xf32 > {
698773 // CHECK: return %arg0
0 commit comments