@@ -1041,3 +1041,173 @@ util.func @move_captured_from_above_ops(%arg0 : tensor<1x1x2x4xi32>,
// CHECK: linalg.yield
// CHECK: flow.return %[[GENERIC]]
// CHECK: util.return %[[DISPATCH]]
1044+
1045+ // -----
1046+
// Horizontally-fused contraction with three results: the multi-result
// linalg.generic and all three truncf consumers must land in ONE dispatch.
util.func @horizontal_fusion1(%lhs : tensor<2x4096x640xf16>,
    %rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>,
    %rhs2 : tensor<10x64x640xf16>) ->
    (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
     tensor<2x10x4096x64xf16>) {
  %4 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0 : f32
  // Zero-filled f32 accumulator shared by all three contraction results.
  %5 = linalg.fill ins(%cst : f32)
      outs(%4 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  // One generic computing three independent contractions over the same %lhs
  // (d4 is the shared reduction dimension).
  %6:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs0, %rhs1, %rhs2
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%5, %5, %5
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
    %14 = arith.extf %in : f16 to f32
    %15 = arith.extf %in_0 : f16 to f32
    %16 = arith.mulf %14, %15 : f32
    %17 = arith.addf %out, %16 : f32
    %18 = arith.extf %in_1 : f16 to f32
    %19 = arith.mulf %14, %18 : f32
    %20 = arith.addf %out_3, %19 : f32
    %21 = arith.extf %in_2 : f16 to f32
    %22 = arith.mulf %14, %21 : f32
    %23 = arith.addf %out_4, %22 : f32
    linalg.yield %17, %20, %23 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
  %7 = tensor.empty() : tensor<2x10x4096x64xf16>
  // Truncation of result #0 back to f16.
  %8 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#0 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  // Truncation of result #1 back to f16.
  %9 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#1 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  // Truncation of result #2 back to f16.
  %10 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%6#2 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
  ^bb0(%in: f32, %out: f16):
    %14 = arith.truncf %in : f32 to f16
    linalg.yield %14 : f16
  } -> tensor<2x10x4096x64xf16>
  util.return %8, %9, %10 : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
}
// CHECK-LABEL: func public @horizontal_fusion1
//       CHECK:   %[[DISPATCH:.+]]:3 = flow.dispatch.region
//       CHECK:     %[[GENERIC:.+]]:3 = linalg.generic
//       CHECK:     %[[TRUNC0:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC]]#0 :
//       CHECK:     %[[TRUNC1:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC]]#1 :
//       CHECK:     %[[TRUNC2:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC]]#2 :
//       CHECK:     flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
//       CHECK:   util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
1124+
1125+ // -----
1126+
// Horizontally-fused i8 contraction with two chained accumulators; the
// combining elementwise consumer fuses into the same dispatch region.
util.func @horizontal_fusion2(%lhs : tensor<2x4096x640xi8>,
    %rhs0 : tensor<2x640x640xi8>, %rhs1 : tensor<2x640x640xi8>)
    -> tensor<2x4096x640xf16> {
  // NOTE(review): the garbled source read `arith.constant 32 : i32`; the SSA
  // name %c0_i32 and the accumulator-fill pattern indicate a zero init.
  %c0_i32 = arith.constant 0 : i32
  %0 = tensor.empty() : tensor<2x4096x640xf16>
  %1 = tensor.empty() : tensor<2x4096x640xi32>
  %2 = linalg.fill ins(%c0_i32 : i32)
      outs(%1 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
  %3:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs0, %rhs1
          : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
      outs(%2, %2 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
  ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
    %4 = arith.extsi %in : i8 to i32
    %5 = arith.extsi %in_0 : i8 to i32
    %6 = arith.muli %4, %5 : i32
    %7 = arith.addi %out, %6 : i32
    %8 = arith.extsi %in_1 : i8 to i32
    // Second accumulator chains on the first partial result %7.
    %9 = arith.muli %7, %8 : i32
    %10 = arith.addi %out_2, %9 : i32
    linalg.yield %7, %10 : i32, i32
  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
  // Combine both results: i32 -> f32 -> f16 conversion, then add.
  %4 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%3#1, %3#0 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
      outs(%0 : tensor<2x4096x640xf16>) {
  ^bb0(%in: i32, %in_0: i32, %out: f16):
    %5 = arith.sitofp %in : i32 to f32
    %6 = arith.truncf %5 : f32 to f16
    %7 = arith.sitofp %in_0 : i32 to f32
    %8 = arith.truncf %7 : f32 to f16
    %9 = arith.addf %6, %8 : f16
    linalg.yield %9 : f16
  } -> tensor<2x4096x640xf16>
  util.return %4 : tensor<2x4096x640xf16>
}
// CHECK-LABEL: func public @horizontal_fusion2
//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
//       CHECK:     %[[GENERIC:.+]]:2 = linalg.generic
//       CHECK:     %[[TRUNC:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC]]#1, %[[GENERIC]]#0 :
//       CHECK:     flow.return %[[TRUNC]]
//       CHECK:   util.return %[[DISPATCH]]
1179+
1180+ // -----
1181+
// Consumer fusion must not pull %0's consumer into the first dispatch: the
// consumer also reads %1, which is produced by an optimization_barrier that
// itself depends on %0 (fusing would violate use-def dominance).
util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor<?xf32>,
    %arg1 : tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
      iterator_types = ["reduction"]}
      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) {
  ^bb0(%b0 : f32, %b1 : f32):
    %1 = arith.addf %b0, %b1 : f32
    linalg.yield %1 : f32
  } -> tensor<f32>
  // Barrier keeps %1 opaque and ordered after %0.
  %1 = util.optimization_barrier %0 : tensor<f32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>],
      iterator_types = []}
      ins(%0, %1 : tensor<f32>, tensor<f32>) outs(%2 : tensor<f32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %4 = arith.mulf %b0, %b1 : f32
    linalg.yield %4 : f32
  } -> tensor<f32>
  util.return %3 : tensor<f32>
}
// CHECK-LABEL: func public @avoid_use_def_violation_on_consumer_fusion
//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.region
//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
//       CHECK:     flow.return %[[GENERIC1]]
//       CHECK:   %[[BARRIER:.+]] = util.optimization_barrier %[[DISPATCH1]]
//       CHECK:   %[[DISPATCH2:.+]] = flow.dispatch.region
//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[DISPATCH1]], %[[BARRIER]] :
//       CHECK:     flow.return %[[GENERIC2]]
//       CHECK:   util.return %[[DISPATCH2]]