@@ -321,3 +321,137 @@ util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x12
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
 // CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
 // CHECK: util.return %[[ATTENTION]]
+
+// -----
+
+util.func public @gather_fusion(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x64x64x640xf16>, %arg2: tensor<2xi64>, %arg3: tensor<640xi64>, %arg4: tensor<128xi64>, %arg5: tensor<640xf16>, %arg6: tensor<f32>) -> tensor<2x128x128x640xi8> {
+  %cst = arith.constant -1.280000e+02 : f16
+  %cst_0 = arith.constant 1.270000e+02 : f16
+  %0 = tensor.empty() : tensor<2x128x128x640xi8>
+  %1 = tensor.empty() : tensor<2x640x64x64xf32>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x64x64x640xf16>, tensor<2x64x64x640xf16>) outs(%1 : tensor<2x640x64x64xf32>) {
+  ^bb0(%in: f16, %in_1: f16, %out: f32):
+    %4 = arith.addf %in, %in_1 : f16
+    %5 = arith.extf %4 : f16 to f32
+    linalg.yield %5 : f32
+  } -> tensor<2x640x64x64xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %arg3, %arg4, %arg4, %arg5, %arg6 : tensor<2xi64>, tensor<640xi64>, tensor<128xi64>, tensor<128xi64>, tensor<640xf16>, tensor<f32>) outs(%0 : tensor<2x128x128x640xi8>) {
+  ^bb0(%in: i64, %in_1: i64, %in_2: i64, %in_3: i64, %in_4: f16, %in_5: f32, %out: i8):
+    %4 = arith.index_cast %in : i64 to index
+    %5 = arith.index_cast %in_1 : i64 to index
+    %6 = arith.index_cast %in_2 : i64 to index
+    %7 = arith.index_cast %in_3 : i64 to index
+    %extracted = tensor.extract %2[%4, %5, %6, %7] : tensor<2x640x64x64xf32>
+    %8 = arith.truncf %extracted : f32 to f16
+    %9 = arith.mulf %8, %in_4 : f16
+    %10 = arith.truncf %in_5 : f32 to f16
+    %11 = arith.divf %9, %10 : f16
+    %12 = math.roundeven %11 : f16
+    %13 = arith.cmpf ult, %12, %cst : f16
+    %14 = arith.select %13, %cst, %12 : f16
+    %15 = arith.cmpf ugt, %14, %cst_0 : f16
+    %16 = arith.select %15, %cst_0, %14 : f16
+    %17 = arith.fptosi %16 : f16 to i8
+    linalg.yield %17 : i8
+  } -> tensor<2x128x128x640xi8>
+  util.return %3 : tensor<2x128x128x640xi8>
+}
+
+// CHECK-LABEL: util.func public @gather_fusion(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG6:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[GEN:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
+// CHECK: ^bb0(
+// CHECK-SAME: %[[IN0:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN1:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN2:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN3:[_a-zA-Z0-9]+]]: i64,
+// CHECK-DAG: %[[CAST0:.+]] = arith.index_cast %[[IN0]] : i64 to index
+// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[IN1]] : i64 to index
+// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[IN2]] : i64 to index
+// CHECK-DAG: %[[CAST3:.+]] = arith.index_cast %[[IN3]] : i64 to index
+// CHECK: %[[EXTRACT0:.*]] = tensor.extract %[[ARG0]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
+// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
+// CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
+// CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>
+
+// -----
+
+util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x64x64x640xf16>, %arg2: tensor<2xi64>, %arg3: tensor<640xi64>, %arg4: tensor<128xi64>, %arg5: tensor<640xf16>, %arg6: tensor<f32>) -> tensor<2x128x128x640xi8> {
+  %cst = arith.constant -1.280000e+02 : f16
+  %cst_0 = arith.constant 1.270000e+02 : f16
+  %0 = tensor.empty() : tensor<2x128x128x640xi8>
+  %1 = tensor.empty() : tensor<2x640x64x64xf32>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<2x64x64x640xf16>, tensor<2x64x64x640xf16>) outs(%1 : tensor<2x640x64x64xf32>) {
+  ^bb0(%in: f16, %in_1: f16, %out: f32):
+    %4 = arith.addf %in, %in_1 : f16
+    %5 = arith.extf %4 : f16 to f32
+    linalg.yield %5 : f32
+  } -> tensor<2x640x64x64xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %arg3, %arg4, %arg4, %arg5, %arg6 : tensor<2xi64>, tensor<640xi64>, tensor<128xi64>, tensor<128xi64>, tensor<640xf16>, tensor<f32>) outs(%0 : tensor<2x128x128x640xi8>) {
+  ^bb0(%in: i64, %in_1: i64, %in_2: i64, %in_3: i64, %in_4: f16, %in_5: f32, %out: i8):
+    %4 = arith.index_cast %in : i64 to index
+    %5 = arith.index_cast %in_1 : i64 to index
+    %6 = arith.index_cast %in_2 : i64 to index
+    %7 = arith.index_cast %in_3 : i64 to index
+    %extracted = tensor.extract %2[%4, %5, %6, %7] : tensor<2x640x64x64xf32>
+    %8 = arith.truncf %extracted : f32 to f16
+    %9 = arith.mulf %8, %in_4 : f16
+    %10 = arith.truncf %in_5 : f32 to f16
+    %11 = arith.divf %9, %10 : f16
+    %12 = math.roundeven %11 : f16
+    %13 = arith.cmpf ult, %12, %cst : f16
+    %14 = arith.select %13, %cst, %12 : f16
+    %15 = arith.cmpf ugt, %14, %cst_0 : f16
+    %16 = arith.select %15, %cst_0, %14 : f16
+    %17 = arith.fptosi %16 : f16 to i8
+    linalg.yield %17 : i8
+  } -> tensor<2x128x128x640xi8>
+  util.return %3 : tensor<2x128x128x640xi8>
+}
+
+// CHECK-LABEL: util.func public @gather_fusion_compose_maps(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG5:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG6:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[GEN:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d1)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-SAME: ins(%[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
+// CHECK: ^bb0(
+// CHECK-SAME: %[[IN0:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN1:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN2:[_a-zA-Z0-9]+]]: i64,
+// CHECK-SAME: %[[IN3:[_a-zA-Z0-9]+]]: i64,
+// CHECK-DAG: %[[CAST0:.+]] = arith.index_cast %[[IN0]] : i64 to index
+// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[IN1]] : i64 to index
+// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[IN2]] : i64 to index
+// CHECK-DAG: %[[CAST3:.+]] = arith.index_cast %[[IN3]] : i64 to index
+// CHECK: %[[EXTRACT0:.*]] = tensor.extract %[[ARG0]][%[[CAST0]], %[[CAST2]], %[[CAST3]], %[[CAST1]]] : tensor<2x64x64x640xf16>
+// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST3]], %[[CAST2]], %[[CAST1]]] : tensor<2x64x64x640xf16>
+// CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
+// CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>