@@ -1083,107 +1083,3 @@ func.func @infusible_pack(%arg0 : tensor<30xf32>) -> tensor<5x6xf32> {
10831083// CHECK: linalg.generic
10841084// CHECK: scf.forall.in_parallel {
10851085// CHECK: linalg.pack
1086-
1087- // -----
1088-
1089- // Adapted from layer normalization. The graph structure is as follows
1090- //
1091- // %14
1092- // / | \
1093- // / %15 %17
1094- // | | / |
1095- // | [%19] |
1096- // %21 | %22
1097- // | | |
1098- // v v v
1099- //
1100- // In particular, %21 and %22 are not users of the "main" tilable
1101- // operation but we still want them to be fused. %19, %21 and %22
1102- // all produce results returned from the function.
1103- //
1104- // Check that everything is fused and that there are three results
1105- // from the loop being produced and returned.
1106- //
1107- // CHECK-LABEL: @multi_result_consumer_fusion
1108- // CHECK-NOT: linalg.generic
1109- // CHECK: %[[LOOP:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (16, 256) shared_outs(%[[OUT0:.+]] = %{{.+}}, %[[OUT1:.+]] = %{{.+}}, %[[OUT2:.+]] = %{{.+}})
1110- // CHECK: %[[v14:.+]] = linalg.generic
1111- // CHECK: arith.divf
1112- // CHECK: %[[v15:.+]] = linalg.generic
1113- // CHECK: arith.subf
1114- // CHECK: %[[v17:.+]] = linalg.generic
1115- // CHECK: arith.divf
1116- // CHECK: math.rsqrt
1117- // CHECK: %[[RES0:.+]] = linalg.generic
1118- // CHECK: arith.mulf
1119- // CHECK: arith.extf
1120- // CHECK: arith.mulf
1121- // CHECK: arith.extf
1122- // CHECK: arith.addf
1123- // CHECK: arith.truncf
1124- // CHECK: %[[RES1:.+]] = linalg.generic {{.*}} ins(%[[v14]] :
1125- // CHECK: arith.truncf
1126- // CHECK: %[[RES2:.+]] = linalg.generic {{.*}} ins(%[[v17]] :
1127- // CHECK: arith.truncf
1128- // CHECK: scf.forall.in_parallel
1129- // CHECK: tensor.parallel_insert_slice %[[RES0]] into %[[OUT0]]
1130- // CHECK: tensor.parallel_insert_slice %[[RES1]] into %[[OUT1]]
1131- // CHECK: tensor.parallel_insert_slice %[[RES2]] into %[[OUT2]]
1132- // CHECK-NOT: linalg.generic
1133- // CHECK: return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2
1134- func.func @multi_result_consumer_fusion (
1135- %6: tensor <16 x256 x2048 xbf16 >,
1136- %7: tensor <2048 xbf16 >,
1137- %8: tensor <2048 xbf16 >,
1138- %10: tensor <16 x256 x2048 xf32 >,
1139- %13: tensor <16 x256 xf32 >
1140- ) -> (
1141- tensor <16 x256 x2048 xbf16 >,
1142- tensor <16 x256 xbf16 >,
1143- tensor <16 x256 xbf16 >
1144- ) {
1145- %cst = arith.constant 0.000000e+00 : f32
1146- %cst_0 = arith.constant 2.048000e+03 : f32
1147- %c0 = arith.constant 0 : index
1148- %9 = tensor.empty () : tensor <16 x256 x2048 xf32 >
1149- %11 = tensor.empty () : tensor <16 x256 xf32 >
1150- %14 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%13 : tensor <16 x256 xf32 >) outs (%11 : tensor <16 x256 xf32 >) {
1151- ^bb0 (%in: f32 , %out: f32 ):
1152- %23 = arith.divf %in , %cst_0 : f32
1153- linalg.yield %23 : f32
1154- } -> tensor <16 x256 xf32 >
1155- %15 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>, affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>, affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" ]} ins (%10 , %14 : tensor <16 x256 x2048 xf32 >, tensor <16 x256 xf32 >) outs (%9 : tensor <16 x256 x2048 xf32 >) {
1156- ^bb0 (%in: f32 , %in_1: f32 , %out: f32 ):
1157- %23 = arith.subf %in , %in_1 : f32
1158- linalg.yield %23 : f32
1159- } -> tensor <16 x256 x2048 xf32 >
1160- %17 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%14 : tensor <16 x256 xf32 >) outs (%11 : tensor <16 x256 xf32 >) {
1161- ^bb0 (%in: f32 , %out: f32 ):
1162- %23 = arith.divf %in , %cst_0 : f32
1163- %24 = math.rsqrt %23 : f32
1164- linalg.yield %24 : f32
1165- } -> tensor <16 x256 xf32 >
1166- %18 = tensor.empty () : tensor <16 x256 x2048 xbf16 >
1167- %19 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>, affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>, affine_map <(d0 , d1 , d2 ) -> (d2 )>, affine_map <(d0 , d1 , d2 ) -> (d2 )>, affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" ]} ins (%15 , %17 , %7 , %8 : tensor <16 x256 x2048 xf32 >, tensor <16 x256 xf32 >, tensor <2048 xbf16 >, tensor <2048 xbf16 >) outs (%18 : tensor <16 x256 x2048 xbf16 >) attrs = {lowering_config = #iree_gpu.lowering_config <{lane_basis = [[1 , 1 , 64 ], [0 , 1 , 2 ]], reduction = [0 , 0 , 256 ], subgroup_basis = [[1 , 1 , 1 ], [0 , 1 , 2 ]], thread = [0 , 0 , 4 ], workgroup = [1 , 1 , 0 ]}>} {
1168- ^bb0 (%in: f32 , %in_1: f32 , %in_2: bf16 , %in_3: bf16 , %out: bf16 ):
1169- %23 = arith.mulf %in , %in_1 : f32
1170- %24 = arith.extf %in_2 : bf16 to f32
1171- %25 = arith.mulf %23 , %24 : f32
1172- %26 = arith.extf %in_3 : bf16 to f32
1173- %27 = arith.addf %25 , %26 : f32
1174- %28 = arith.truncf %27 : f32 to bf16
1175- linalg.yield %28 : bf16
1176- } -> tensor <16 x256 x2048 xbf16 >
1177- %20 = tensor.empty () : tensor <16 x256 xbf16 >
1178- %21 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%14 : tensor <16 x256 xf32 >) outs (%20 : tensor <16 x256 xbf16 >) {
1179- ^bb0 (%in: f32 , %out: bf16 ):
1180- %23 = arith.truncf %in : f32 to bf16
1181- linalg.yield %23 : bf16
1182- } -> tensor <16 x256 xbf16 >
1183- %22 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>, affine_map <(d0 , d1 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" ]} ins (%17 : tensor <16 x256 xf32 >) outs (%20 : tensor <16 x256 xbf16 >) {
1184- ^bb0 (%in: f32 , %out: bf16 ):
1185- %23 = arith.truncf %in : f32 to bf16
1186- linalg.yield %23 : bf16
1187- } -> tensor <16 x256 xbf16 >
1188- return %19 , %21 , %22 : tensor <16 x256 x2048 xbf16 >, tensor <16 x256 xbf16 >, tensor <16 x256 xbf16 >
1189- }
0 commit comments