@@ -345,3 +345,46 @@ func.func @mlir_attention_single_token(%arg0: tensor<128xf32>, %arg1: tensor<256
345345 %collapsed_7 = tensor.collapse_shape %20 [[0 , 1 , 2 ]] : tensor <8 x1 x32 xf32 > into tensor <256 xf32 >
346346 return %collapsed_7 , %collapsed_4 : tensor <256 xf32 >, tensor <8 xf32 >
347347}
348+
// Verifies that the unfolded softmax/log-sum-exp pattern below (max-subtract,
// exp, reciprocal-normalize, plus the log+add LSE computation) is fused into a
// single rock.attention op that also produces the LSE result: an 8x1 buffer is
// allocated for it, rock.attention writes it via `lse = ...`, and the value is
// reshaped (expand then collapse) back to the tensor<8xf32> returned alongside
// the tensor<256xf32> attention output.
// CHECK-LABEL: @mlir_attention_lse_unfolded
// CHECK: %[[lseBuffer:.+]] = bufferization.alloc_tensor() : tensor<8x1xf32>
// CHECK: %{{.*}}, %[[lseOut:.*]] = rock.attention
// CHECK: lse = %[[lseBuffer]] : tensor<8x1xf32>
// CHECK: %[[lseExpanded:.*]] = tensor.expand_shape %[[lseOut]]
// CHECK: %[[lseCollapsed:.*]] = tensor.collapse_shape %[[lseExpanded]]
// CHECK: return %{{.*}}, %[[lseCollapsed]] : tensor<256xf32>, tensor<8xf32>
func.func private @mlir_attention_lse_unfolded(%arg0: tensor<128xf32>, %arg1: tensor<256xf32>, %arg2: tensor<128xf32>) -> (tensor<256xf32>, tensor<8xf32>) attributes {arch = "##TOKEN_ARCH##", kernel} {
  %0 = tosa.const_shape {values = dense<256> : tensor<1xindex>} : () -> !tosa.shape<1>
  %1 = tosa.const_shape {values = dense<[8, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %2 = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1>
  %3 = tosa.const_shape {values = dense<[2, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %4 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
  %5 = tosa.const_shape {values = dense<[8, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %6 = tosa.const_shape {values = dense<[8, 1, 32]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %7 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %8 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<2x2x2x1x32xf32>}> : () -> tensor<2x2x2x1x32xf32>
  %9 = tosa.const_shape {values = dense<[2, 2, 1, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5>
  %10 = tosa.const_shape {values = dense<[2, 4, 1, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
  // Broadcast-multiply by an all-ones constant expands K (%arg0) and V (%arg2)
  // from 2x2x1x1x32 to 2x2x2x1x32 before the QK^T and PV matmuls.
  %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
  %11 = tosa.mul %expanded, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
  %expanded_0 = tensor.expand_shape %arg2 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
  %12 = tosa.mul %expanded_0, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
  %collapsed = tensor.collapse_shape %12 [[0], [1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<2x4x1x32xf32>
  %13 = tosa.transpose %collapsed {perms = array<i32: 0, 1, 3, 2>} : (tensor<2x4x1x32xf32>) -> tensor<2x4x32x1xf32>
  %expanded_1 = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [8, 1, 32] : tensor<256xf32> into tensor<8x1x32xf32>
  %collapsed_2 = tensor.collapse_shape %13 [[0, 1], [2], [3]] : tensor<2x4x32x1xf32> into tensor<8x32x1xf32>
  // QK^T: batched (8x1x32) x (8x32x1) -> 8x1x1 attention scores.
  %14 = tosa.matmul %expanded_1, %collapsed_2, %4, %4 {acc_type = f32} : (tensor<8x1x32xf32>, tensor<8x32x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x1xf32>
  %expanded_3 = tensor.expand_shape %14 [[0, 1], [2], [3]] output_shape [2, 4, 1, 1] : tensor<8x1x1xf32> into tensor<2x4x1x1xf32>
  // Unfolded softmax over a single-element reduction axis: subtract the
  // (trivial) max, exponentiate, normalize by the (trivial) sum.
  %15 = tosa.sub %expanded_3, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
  %16 = tosa.exp %15 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
  %17 = tosa.reciprocal %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
  %18 = tosa.mul %16, %17, %7 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>, tensor<1xi8>) -> tensor<2x4x1x1xf32>
  // LSE = log(sum(exp(scores - max))) + max, computed explicitly so the
  // pattern matcher must recognize it and route it through rock.attention.
  %19 = tosa.log %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
  %20 = tosa.add %19, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
  %collapsed_4 = tensor.collapse_shape %20 [[0, 1, 2, 3]] : tensor<2x4x1x1xf32> into tensor<8xf32>
  %collapsed_5 = tensor.collapse_shape %18 [[0, 1], [2], [3]] : tensor<2x4x1x1xf32> into tensor<8x1x1xf32>
  %collapsed_6 = tensor.collapse_shape %11 [[0, 1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<8x1x32xf32>
  // PV: batched (8x1x1) x (8x1x32) -> 8x1x32 attention output.
  %21 = tosa.matmul %collapsed_5, %collapsed_6, %4, %4 {acc_type = f32} : (tensor<8x1x1xf32>, tensor<8x1x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x32xf32>
  %collapsed_7 = tensor.collapse_shape %21 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32>
  return %collapsed_7, %collapsed_4 : tensor<256xf32>, tensor<8xf32>
}
// (end of chunk)