@@ -91,3 +91,58 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
9191// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
9292// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
9393// CHECK: }
94+
95+ // -----
96+
97+ // CHECK-LABEL: test_fold_small_const_concat
98+ func.func @test_fold_small_const_concat () -> tensor <6 xi8 > {
99+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8>
100+ // CHECK: return %[[VAL_0]] : tensor<6xi8>
101+ %0 = " tosa.const" () <{values = dense <[1 , 2 ]> : tensor <2 xi8 >}> : () -> tensor <2 xi8 >
102+ %1 = " tosa.const" () <{values = dense <[3 , 4 , 5 ]> : tensor <3 xi8 >}> : () -> tensor <3 xi8 >
103+ %2 = " tosa.const" () <{values = dense <6 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
104+ %3 = " tosa.concat" (%0 , %1 , %2 ) <{axis = 0 : i32 }> : (tensor <2 xi8 >, tensor <3 xi8 >, tensor <1 xi8 >) -> tensor <6 xi8 >
105+ func.return %3 : tensor <6 xi8 >
106+ }
107+
108+ // -----
109+
110+ // CHECK-LABEL: test_no_fold_small_const_concat_with_non_const
111+ func.func @test_no_fold_small_const_concat_with_non_const (%arg0: tensor <2 xi8 >, %arg1: tensor <3 xi8 >, %arg2: tensor <1 xi8 >) -> tensor <6 xi8 > {
112+ // CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg1, %arg2 {axis = 0 : i32} : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
113+ // CHECK: return %[[VAL_3]] : tensor<6xi8>
114+ %1 = " tosa.concat" (%arg0 , %arg1 , %arg2 ) <{axis = 0 : i32 }> : (tensor <2 xi8 >, tensor <3 xi8 >, tensor <1 xi8 >) -> tensor <6 xi8 >
115+ func.return %1 : tensor <6 xi8 >
116+ }
117+
118+ // -----
119+
120+ // CHECK-LABEL: test_no_fold_small_const_concat_with_higher_dim
121+ func.func @test_no_fold_small_const_concat_with_higher_dim () -> tensor <7 xi8 > {
122+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8>
123+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8>
124+ // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
125+ // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 0 : i32} : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8>
126+ // CHECK: return %[[VAL_3]] : tensor<7xi8>
127+ %0 = " tosa.const" () <{values = dense <[1 , 2 , 3 ]> : tensor <3 xi8 >}> : () -> tensor <3 xi8 >
128+ %1 = " tosa.const" () <{values = dense <[4 , 5 , 6 ]> : tensor <3 xi8 >}> : () -> tensor <3 xi8 >
129+ %2 = " tosa.const" () <{values = dense <7 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
130+ %3 = " tosa.concat" (%0 , %1 , %2 ) <{axis = 0 : i32 }> : (tensor <3 xi8 >, tensor <3 xi8 >, tensor <1 xi8 >) -> tensor <7 xi8 >
131+ func.return %3 : tensor <7 xi8 >
132+ }
133+
134+ // -----
135+
136+ // CHECK-LABEL: test_no_fold_small_const_concat_with_higher_rank
137+ func.func @test_no_fold_small_const_concat_with_higher_rank () -> tensor <1 x6 xi8 > {
138+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8>
139+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8>
140+ // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<6> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
141+ // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8>
142+ // CHECK: return %[[VAL_3]] : tensor<1x6xi8>
143+ %0 = " tosa.const" () <{values = dense <[[1 , 2 ]]> : tensor <1 x2 xi8 >}> : () -> tensor <1 x2 xi8 >
144+ %1 = " tosa.const" () <{values = dense <[[3 , 4 , 5 ]]> : tensor <1 x3 xi8 >}> : () -> tensor <1 x3 xi8 >
145+ %2 = " tosa.const" () <{values = dense <[[6 ]]> : tensor <1 x1 xi8 >}> : () -> tensor <1 x1 xi8 >
146+ %3 = " tosa.concat" (%0 , %1 , %2 ) <{axis = 1 : i32 }> : (tensor <1 x2 xi8 >, tensor <1 x3 xi8 >, tensor <1 x1 xi8 >) -> tensor <1 x6 xi8 >
147+ func.return %3 : tensor <1 x6 xi8 >
148+ }
0 commit comments