@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
11761176
11771177// -----
11781178
// Back-to-back broadcasts, trailing-dimension case: tensor<2xf32> is expanded
// to tensor<2x3xf32> (new dim at index 1) and that result is expanded again to
// tensor<2x3x4xf32> (new dim at index 2). The expectations below accept only
// one broadcast that goes from the original input straight into the final
// init and lists both added dims; no second broadcast may remain.
// NOTE(review): the RUN line is outside this hunk; presumably this runs under
// the canonicalizer -- confirm.
// CHECK-LABEL: @broadcast_broadcast_fold
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
// CHECK-NOT: linalg.broadcast
// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
                                    %init1: tensor<2x3xf32>,
                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  // Inner step: 2 -> 2x3, the added dimension sits at index 1.
  %inner = linalg.broadcast
      ins(%input : tensor<2xf32>)
      outs(%init1 : tensor<2x3xf32>)
      dimensions = [1]
  // Outer step: 2x3 -> 2x3x4, the added dimension sits at index 2.
  %outer = linalg.broadcast
      ins(%inner : tensor<2x3xf32>)
      outs(%init2 : tensor<2x3x4xf32>)
      dimensions = [2]
  return %outer : tensor<2x3x4xf32>
}
1199+
1200+ // -----
1201+
// Back-to-back broadcasts, interleaved case: tensor<2xf32> is expanded to
// tensor<2x4xf32> (new dim at index 1), then to tensor<2x3x4xf32> with the
// outer step also adding at index 1. The outer step's size-3 dim lands in
// front of the size-4 dim from the inner step, so that inner dim ends up at
// index 2 of the final shape. The expectations below therefore require one
// broadcast from the original input into the final init listing dims 1 and 2,
// with no second broadcast left over.
// NOTE(review): the RUN line is outside this hunk; presumably this runs under
// the canonicalizer -- confirm.
// CHECK-LABEL: @broadcast_broadcast_fold
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
// CHECK-NOT: linalg.broadcast
// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
                                    %init1: tensor<2x4xf32>,
                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  // Inner step: 2 -> 2x4, the added dimension (size 4) sits at index 1.
  %inner = linalg.broadcast
      ins(%input : tensor<2xf32>)
      outs(%init1 : tensor<2x4xf32>)
      dimensions = [1]
  // Outer step: 2x4 -> 2x3x4, the added dimension (size 3) sits at index 1.
  %outer = linalg.broadcast
      ins(%inner : tensor<2x4xf32>)
      outs(%init2 : tensor<2x3x4xf32>)
      dimensions = [1]
  return %outer : tensor<2x3x4xf32>
}
1222+
1223+ // -----
1224+
11791225func.func @transpose_1d (%input: tensor <16 xf32 >,
11801226 %init: tensor <16 xf32 >) -> tensor <16 xf32 > {
11811227 %transpose = linalg.transpose
0 commit comments