@@ -27,7 +27,6 @@ util.func public @transpose_attention(%arg0: tensor<4x64x32x128xf16>, %arg1: ten
   %collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
   util.return %collapsed : tensor<4x64x4096xf16>
 }
-
 // CHECK-LABEL: util.func public @transpose_attention
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -76,7 +75,6 @@ util.func public @transposed_attention_masked(%arg0: tensor<4x64x32x128xf16>, %a
   %collapsed = tensor.collapse_shape %8 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
   util.return %collapsed : tensor<4x64x4096xf16>
 }
-
 // CHECK-LABEL: util.func public @transposed_attention_masked
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -115,7 +113,6 @@ util.func public @transpose_matmul(%arg0 : tensor<100x100xf16>, %arg1 : tensor<1
   } -> tensor<100x100xf16>
   util.return %4 : tensor<100x100xf16>
 }
-
 // CHECK-LABEL: util.func public @transpose_matmul
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -156,7 +153,6 @@ util.func public @fuse_generic_gather(
   } -> tensor<4x?x4096xf32>
   util.return %16 : tensor<4x?x4096xf32>
 }
-
 // CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
 // CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
 // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
@@ -198,7 +194,6 @@ util.func public @fuse_generic_gather2(
   } -> tensor<4x?x4096xf32>
   util.return %16 : tensor<4x?x4096xf32>
 }
-
 // CHECK: %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
 // CHECK: %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
 // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
@@ -237,7 +232,6 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
   } -> tensor<2x10x4096x64xf16>
   util.return %attention : tensor<2x10x4096x64xf16>
 }
-
 // CHECK-LABEL: util.func public @fuse_transpose_attention_to_producer
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -274,7 +268,6 @@ util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %a
   } -> tensor<4x8x4x?x32x128xf16>
   util.return %1 : tensor<4x8x4x?x32x128xf16>
 }
-
 // CHECK-LABEL: func public @fuse_attention_with_broadcast
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
@@ -305,7 +298,6 @@ util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x12
   } -> tensor<4x8x4x?x32x128xf16>
   util.return %1 : tensor<4x8x4x?x32x128xf16>
 }
-
 // CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
@@ -356,7 +348,6 @@ util.func public @gather_fusion(%arg0: tensor<2x64x64x640xf16>, %arg1: tensor<2x
   } -> tensor<2x128x128x640xi8>
   util.return %3 : tensor<2x128x128x640xi8>
 }
-
 // CHECK-LABEL: util.func public @gather_fusion(
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -423,7 +414,6 @@ util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %ar
   } -> tensor<2x128x128x640xi8>
   util.return %3 : tensor<2x128x128x640xi8>
 }
-
 // CHECK-LABEL: util.func public @gather_fusion_compose_maps(
 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
@@ -455,3 +445,94 @@ util.func public @gather_fusion_compose_maps(%arg0: tensor<2x64x64x640xf16>, %ar
 // CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][%[[CAST0]], %[[CAST3]], %[[CAST2]], %[[CAST1]]] : tensor<2x64x64x640xf16>
 // CHECK: %[[ADDF:.+]] = arith.addf %[[EXTRACT0]], %[[EXTRACT1]] : f16
 // CHECK: util.return %[[GEN]] : tensor<2x128x128x640xi8>
+
+// -----
+
+util.func public @gather_0d_producer(%arg0 : tensor<f16>, %arg1 : tensor<100xindex>, %arg2 : tensor<256xf16>) -> (tensor<100xf32>) {
+  %empty0 = tensor.empty() : tensor<256xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg2 : tensor<f16>, tensor<256xf16>) outs(%empty0 : tensor<256xf32>) {
+  ^bb0(%in: f16, %in0: f16, %out: f32):
+    %0 = arith.extf %in : f16 to f32
+    %1 = arith.extf %in0 : f16 to f32
+    %2 = arith.addf %0, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<256xf32>
+  %empty1 = tensor.empty() : tensor<100xf32>
+  %gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
+  ^bb0(%in: index, %out: f32):
+    %1 = tensor.extract %0[%in] : tensor<256xf32>
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  util.return %gather : tensor<100xf32>
+}
+// CHECK-LABEL: util.func public @gather_0d_producer(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[GATHER:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: index
+// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract %[[ARG0]][]
+// CHECK-DAG: %[[EXTRACT1:.+]] = tensor.extract %[[ARG2]][%[[IN]]]
+// CHECK: return %[[GATHER]]
+
+// -----
+
+util.func public @gather_replace_linalg_index(%arg0 : tensor<256x256xf16>, %arg1 : tensor<100xindex>) -> (tensor<100xf32>) {
+  %empty0 = tensor.empty() : tensor<256x256xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<256x256xf16>) outs(%empty0 : tensor<256x256xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    %0 = arith.extf %in : f16 to f32
+    %1 = linalg.index 1 : index
+    %2 = arith.index_cast %1 : index to i32
+    %3 = arith.uitofp %2 : i32 to f32
+    %4 = arith.addf %0, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<256x256xf32>
+  %empty1 = tensor.empty() : tensor<100xf32>
+  %gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
+  ^bb0(%in: index, %out: f32):
+    %cst0 = arith.constant 0 : index
+    %1 = tensor.extract %0[%cst0, %in] : tensor<256x256xf32>
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  util.return %gather : tensor<100xf32>
+}
+// CHECK-LABEL: util.func public @gather_replace_linalg_index(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[GATHER:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: index
+// CHECK: arith.index_cast %[[IN]]
+// CHECK: return %[[GATHER]]
+
+// -----
+
+util.func public @gather_replace_linalg_index_transpose(%arg0 : tensor<256x256xf16>, %arg1 : tensor<100xindex>, %arg2 : index) -> (tensor<100xf32>) {
+  %empty0 = tensor.empty() : tensor<256x256xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<256x256xf16>) outs(%empty0 : tensor<256x256xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    %0 = arith.extf %in : f16 to f32
+    %1 = linalg.index 1 : index
+    %2 = arith.index_cast %1 : index to i32
+    %3 = arith.uitofp %2 : i32 to f32
+    %4 = arith.addf %0, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<256x256xf32>
+  %empty1 = tensor.empty() : tensor<100xf32>
+  %gather = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg1: tensor<100xindex>) outs(%empty1 : tensor<100xf32>) {
+  ^bb0(%in: index, %out: f32):
+    %1 = tensor.extract %0[%arg2, %in] : tensor<256x256xf32>
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  util.return %gather : tensor<100xf32>
+}
+// CHECK-LABEL: util.func public @gather_replace_linalg_index_transpose(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: index
+// CHECK: %[[GATHER:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
+// CHECK: arith.index_cast %[[ARG2]]
+// CHECK: return %[[GATHER]]