@@ -208,6 +208,8 @@ util.func public @fuse_generic_gather2(
208208// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
209209// CHECK-NEXT: linalg.yield %[[RES4]] : f32
210210
211+ // -----
212+
211213util.func public @fuse_transpose_attention_to_producer (%q: tensor <2 x10 x4096 x64 xf16 >, %k: tensor <2 x10 x4096 x64 xf16 >, %quantized_v: tensor <2 x10 x4096 x64 xi32 >, %quant_offset: tensor <10 x64 xi32 >, %quant_scale: tensor <10 x64 xf32 >, %scale: f16 ) -> tensor <2 x10 x4096 x64 xf16 > {
212214 // Dequantize int-quantization of V
213215 %init_dequant = tensor.empty () : tensor <2 x10 x4096 x64 xf16 >
@@ -258,3 +260,64 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
258260// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
259261// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
260262// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
263+
264+ // -----
265+
266+ util.func public @fuse_attention_with_broadcast (%arg0: tensor <4 x8 x128 x?xf16 >, %arg1: tensor <4 x8 x4 x?x32 x128 xf16 >, %arg2: tensor <4 x8 x4 x?x128 xf16 >, %arg3: f16 , %arg4: tensor <4 x8 x4 x?x32 x?xf16 >, %arg5: tensor <4 x8 x4 x?x32 x128 xf16 >, %arg6: tensor <4 x8 x4 x128 x?xf16 >) -> tensor <4 x8 x4 x?x32 x128 xf16 > {
267+ %0 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d3 , d4 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d3 , d4 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" , " parallel" ]} ins (%arg0 : tensor <4 x8 x128 x?xf16 >) outs (%arg6 : tensor <4 x8 x4 x128 x?xf16 >) {
268+ ^bb0 (%in: f16 , %out: f16 ):
269+ linalg.yield %in : f16
270+ } -> tensor <4 x8 x4 x128 x?xf16 >
271+ %1 = iree_linalg_ext.attention {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d6 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d7 , d6 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d5 , d7 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> ()>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d7 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d5 )>]} ins (%arg1 , %arg2 , %0 , %arg3 , %arg4 : tensor <4 x8 x4 x?x32 x128 xf16 >, tensor <4 x8 x4 x?x128 xf16 >, tensor <4 x8 x4 x128 x?xf16 >, f16 , tensor <4 x8 x4 x?x32 x?xf16 >) outs (%arg5 : tensor <4 x8 x4 x?x32 x128 xf16 >) {
272+ ^bb0 (%arg7: f32 ):
273+ iree_linalg_ext.yield %arg7 : f32
274+ } -> tensor <4 x8 x4 x?x32 x128 xf16 >
275+ util.return %1 : tensor <4 x8 x4 x?x32 x128 xf16 >
276+ }
277+
278+ // CHECK-LABEL: func public @fuse_attention_with_broadcast
279+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
280+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
281+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
282+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
283+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
284+ // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
285+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
286+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
287+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>,
288+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
289+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
290+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
291+ // CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
292+ // CHECK: util.return %[[ATTENTION]]
293+
294+
295+ // -----
296+
297+ util.func public @fuse_attention_with_broadcast_transpose (%arg0: tensor <4 x?x8 x128 xf16 >, %arg1: tensor <4 x8 x4 x?x32 x128 xf16 >, %arg2: tensor <4 x8 x4 x?x128 xf16 >, %arg3: f16 , %arg4: tensor <4 x8 x4 x?x32 x?xf16 >, %arg5: tensor <4 x8 x4 x?x32 x128 xf16 >, %arg6: tensor <4 x8 x4 x128 x?xf16 >) -> tensor <4 x8 x4 x?x32 x128 xf16 > {
298+ %0 = linalg.generic {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d4 )>, affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d2 , d3 , d4 , d1 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" , " parallel" ]} ins (%arg0 : tensor <4 x?x8 x128 xf16 >) outs (%arg6 : tensor <4 x8 x4 x128 x?xf16 >) {
299+ ^bb0 (%in: f16 , %out: f16 ):
300+ linalg.yield %in : f16
301+ } -> tensor <4 x8 x4 x128 x?xf16 >
302+ %1 = iree_linalg_ext.attention {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d6 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d7 , d6 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d5 , d7 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> ()>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d7 )>, affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 , d7 ) -> (d0 , d1 , d2 , d3 , d4 , d5 )>]} ins (%arg1 , %arg2 , %0 , %arg3 , %arg4 : tensor <4 x8 x4 x?x32 x128 xf16 >, tensor <4 x8 x4 x?x128 xf16 >, tensor <4 x8 x4 x128 x?xf16 >, f16 , tensor <4 x8 x4 x?x32 x?xf16 >) outs (%arg5 : tensor <4 x8 x4 x?x32 x128 xf16 >) {
303+ ^bb0 (%arg7: f32 ):
304+ iree_linalg_ext.yield %arg7 : f32
305+ } -> tensor <4 x8 x4 x?x32 x128 xf16 >
306+ util.return %1 : tensor <4 x8 x4 x?x32 x128 xf16 >
307+ }
308+
309+ // CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
310+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
311+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
312+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
313+ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
314+ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
315+ // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
316+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
317+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
318+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5)>,
319+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
320+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
321+ // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
322+ // CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
323+ // CHECK: util.return %[[ATTENTION]]
0 commit comments