@@ -62,6 +62,43 @@ func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
6262
6363// -----
6464
65+ func.func @concat_multiple_users (%arg0: tensor <1 x1 x7 x7 xf32 >, %arg1: tensor <1 x1 x7 x7 xf32 >) -> (tensor <1 x3 x7 x7 xf32 >, tensor <1 x2 x7 x7 xf32 >) {
66+ %tmp = tensor.empty () : tensor <1 x1 x7 x7 xf32 >
67+ %0 = tosa.concat %arg0 , %arg1 {axis = 1 : i32 } : (tensor <1 x1 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x2 x7 x7 xf32 >
68+ %1 = tosa.concat %tmp , %0 {axis = 1 : i32 } : (tensor <1 x1 x7 x7 xf32 >, tensor <1 x2 x7 x7 xf32 >) -> tensor <1 x3 x7 x7 xf32 >
69+ %2 = tosa.add %0 , %0 : (tensor <1 x2 x7 x7 xf32 >, tensor <1 x2 x7 x7 xf32 >) -> tensor <1 x2 x7 x7 xf32 >
70+ return %1 , %2 : tensor <1 x3 x7 x7 xf32 >, tensor <1 x2 x7 x7 xf32 >
71+ }
72+
73+ // CHECK-LABEL: func.func @concat_multiple_users
74+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>) {
75+ // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1x1x7x7xf32>
76+ // CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
77+ // CHECK-NOT: separator of consecutive DAGs
78+ // CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32>
79+ // CHECK-DAG: [[VAR_3_:%.+]] = tosa.add [[VAR_1_]], [[VAR_1_]] : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32>
80+ // CHECK: return [[VAR_2_]], [[VAR_3_]] : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>
81+ // CHECK: }
82+
83+ // -----
84+
85+ func.func @concat_diamond_shape (%arg0: tensor <1 x1 x7 x7 xf32 >, %arg1: tensor <1 x1 x7 x7 xf32 >, %arg2: tensor <1 x1 x7 x7 xf32 >, %arg3: tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x6 x7 x7 xf32 > {
86+ %tmp = tensor.empty () : tensor <1 x1 x7 x7 xf32 >
87+ %0 = tosa.concat %arg0 , %arg1 {axis = 1 : i32 } : (tensor <1 x1 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x2 x7 x7 xf32 >
88+ %1 = tosa.concat %0 , %arg2 {axis = 1 : i32 } : (tensor <1 x2 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x3 x7 x7 xf32 >
89+ %2 = tosa.concat %0 , %arg3 {axis = 1 : i32 } : (tensor <1 x2 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x3 x7 x7 xf32 >
90+ %3 = tosa.concat %1 , %2 {axis = 1 : i32 } : (tensor <1 x3 x7 x7 xf32 >, tensor <1 x3 x7 x7 xf32 >) -> tensor <1 x6 x7 x7 xf32 >
91+ return %3 : tensor <1 x6 x7 x7 xf32 >
92+ }
93+
94+ // CHECK-LABEL: func.func @concat_diamond_shape
95+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_2_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_3_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> {
96+ // CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_0_]], [[PARAM_1_]], [[PARAM_3_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32>
97+ // CHECK: return [[VAR_0_]] : tensor<1x6x7x7xf32>
98+ // CHECK: }
99+
100+ // -----
101+
65102func.func @wide_fold (%arg0: tensor <1 x1 x7 x7 xf32 >, %arg1: tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x4 x7 x7 xf32 > {
66103 %0 = tosa.concat %arg0 , %arg0 {axis = 1 : i32 } : (tensor <1 x1 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x2 x7 x7 xf32 >
67104 %1 = tosa.concat %arg1 , %arg1 {axis = 1 : i32 } : (tensor <1 x1 x7 x7 xf32 >, tensor <1 x1 x7 x7 xf32 >) -> tensor <1 x2 x7 x7 xf32 >
@@ -91,4 +128,4 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
91128// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_1_]], [[VAR_0_]] : (tensor<1x2x4x8xf32>, !tosa.shape<4>) -> tensor<1x2x8x8xf32>
92129// CHECK: [[VAR_2_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
93130// CHECK: return [[VAR_2_]] : tensor<1x4x8x8xf32>
94- // CHECK: }
131+ // CHECK: }
0 commit comments