Skip to content

Commit 16fab03

Browse files
authored
Update LSESeqLen1 pattern matching (#2074)
* Add support for unfolded lse
* Clang-format
* Add additional comment
1 parent e4ab0c1 commit 16fab03

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,11 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
15911591
log(sum(exp(sub(x, x)))) + max(x)
15921592
= log(exp(sub(x, x))) + x
15931593
= sub(x, x) + x
1594+
1595+
Upstream disabled folding of log(exp(..)) by default, so we need to match the
1596+
following two patterns:
1597+
1. The folded pattern: sub(x, x) + x
1598+
2. The unfolded pattern: log(exp(sub(x, x))) + x
15941599
*/
15951600
Value getLSESeqLen1(tosa::SubOp subOp) const {
15961601
if (subOp.getInput1() != subOp.getInput2()) {
@@ -1599,6 +1604,7 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
15991604
}
16001605
Value subInput = subOp.getInput1();
16011606
for (Operation *user : subOp->getUsers()) {
1607+
// Pattern 1: Check for direct add: sub(x, x) + x
16021608
if (tosa::AddOp addOp = dyn_cast<tosa::AddOp>(user)) {
16031609
Value addOpInput1 = addOp.getInput1();
16041610
Value addOpInput2 = addOp.getInput2();
@@ -1613,9 +1619,36 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
16131619
}
16141620
}
16151621
}
1622+
1623+
// Pattern 2: Check for log(exp(sub(x, x))) + x
1624+
tosa::ExpOp expOp = dyn_cast<tosa::ExpOp>(user);
1625+
if (!expOp)
1626+
continue;
1627+
1628+
for (Operation *expUser : expOp->getUsers()) {
1629+
tosa::LogOp logOp = dyn_cast<tosa::LogOp>(expUser);
1630+
if (!logOp)
1631+
continue;
1632+
1633+
for (Operation *logUser : logOp->getUsers()) {
1634+
tosa::AddOp addOp = dyn_cast<tosa::AddOp>(logUser);
1635+
if (!addOp)
1636+
continue;
1637+
1638+
Value addOpInput1 = addOp.getInput1();
1639+
Value addOpInput2 = addOp.getInput2();
1640+
// Check if one input is the log result and the other is the
1641+
// original subInput (x)
1642+
if ((addOpInput1 == logOp.getOutput() && addOpInput2 == subInput) ||
1643+
(addOpInput2 == logOp.getOutput() && addOpInput1 == subInput)) {
1644+
return addOp.getOutput();
1645+
}
1646+
}
1647+
}
16161648
}
16171649
return nullptr;
16181650
}
1651+
16191652
/**
16201653
* Attempts to match and extract a Log-Sum-Exp (LSE) pattern from TOSA
16211654
* operations.
@@ -1980,8 +2013,8 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
19802013
if (hasReduceOp) {
19812014
lse = getLSE(rsum, rmax);
19822015
} else {
1983-
// if there is no reduce op, then we have seq_len=1 and lse is just
1984-
// sub(x, x) + x
2016+
// if there is no reduce op, then we have seq_len=1 and lse is either
2017+
// sub(x, x) + x or log(exp(sub(x, x))) + x
19852018
lse = getLSESeqLen1(cast<tosa::SubOp>(sub));
19862019
}
19872020
// lse has three or four dimensions

mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-lse.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,46 @@ func.func @mlir_attention_single_token(%arg0: tensor<128xf32>, %arg1: tensor<256
345345
%collapsed_7 = tensor.collapse_shape %20 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32>
346346
return %collapsed_7, %collapsed_4 : tensor<256xf32>, tensor<8xf32>
347347
}
348+
349+
// CHECK-LABEL: @mlir_attention_lse_unfolded
350+
// CHECK: %[[lseBuffer:.+]] = bufferization.alloc_tensor() : tensor<8x1xf32>
351+
// CHECK: %{{.*}}, %[[lseOut:.*]] = rock.attention
352+
// CHECK: lse = %[[lseBuffer]] : tensor<8x1xf32>
353+
// CHECK: %[[lseExpanded:.*]] = tensor.expand_shape %[[lseOut]]
354+
// CHECK: %[[lseCollapsed:.*]] = tensor.collapse_shape %[[lseExpanded]]
355+
// CHECK: return %{{.*}}, %[[lseCollapsed]] : tensor<256xf32>, tensor<8xf32>
356+
func.func private @mlir_attention_lse_unfolded(%arg0: tensor<128xf32>, %arg1: tensor<256xf32>, %arg2: tensor<128xf32>) -> (tensor<256xf32>, tensor<8xf32>) attributes {arch = "##TOKEN_ARCH##", kernel} {
357+
%0 = tosa.const_shape {values = dense<256> : tensor<1xindex>} : () -> !tosa.shape<1>
358+
%1 = tosa.const_shape {values = dense<[8, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
359+
%2 = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1>
360+
%3 = tosa.const_shape {values = dense<[2, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
361+
%4 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
362+
%5 = tosa.const_shape {values = dense<[8, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
363+
%6 = tosa.const_shape {values = dense<[8, 1, 32]> : tensor<3xindex>} : () -> !tosa.shape<3>
364+
%7 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
365+
%8 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<2x2x2x1x32xf32>}> : () -> tensor<2x2x2x1x32xf32>
366+
%9 = tosa.const_shape {values = dense<[2, 2, 1, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5>
367+
%10 = tosa.const_shape {values = dense<[2, 4, 1, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
368+
%expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
369+
%11 = tosa.mul %expanded, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
370+
%expanded_0 = tensor.expand_shape %arg2 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
371+
%12 = tosa.mul %expanded_0, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
372+
%collapsed = tensor.collapse_shape %12 [[0], [1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<2x4x1x32xf32>
373+
%13 = tosa.transpose %collapsed {perms = array<i32: 0, 1, 3, 2>} : (tensor<2x4x1x32xf32>) -> tensor<2x4x32x1xf32>
374+
%expanded_1 = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [8, 1, 32] : tensor<256xf32> into tensor<8x1x32xf32>
375+
%collapsed_2 = tensor.collapse_shape %13 [[0, 1], [2], [3]] : tensor<2x4x32x1xf32> into tensor<8x32x1xf32>
376+
%14 = tosa.matmul %expanded_1, %collapsed_2, %4, %4 {acc_type = f32} : (tensor<8x1x32xf32>, tensor<8x32x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x1xf32>
377+
%expanded_3 = tensor.expand_shape %14 [[0, 1], [2], [3]] output_shape [2, 4, 1, 1] : tensor<8x1x1xf32> into tensor<2x4x1x1xf32>
378+
%15 = tosa.sub %expanded_3, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
379+
%16 = tosa.exp %15 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
380+
%17 = tosa.reciprocal %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
381+
%18 = tosa.mul %16, %17, %7 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>, tensor<1xi8>) -> tensor<2x4x1x1xf32>
382+
%19 = tosa.log %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
383+
%20 = tosa.add %19, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
384+
%collapsed_4 = tensor.collapse_shape %20 [[0, 1, 2, 3]] : tensor<2x4x1x1xf32> into tensor<8xf32>
385+
%collapsed_5 = tensor.collapse_shape %18 [[0, 1], [2], [3]] : tensor<2x4x1x1xf32> into tensor<8x1x1xf32>
386+
%collapsed_6 = tensor.collapse_shape %11 [[0, 1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<8x1x32xf32>
387+
%21 = tosa.matmul %collapsed_5, %collapsed_6, %4, %4 {acc_type = f32} : (tensor<8x1x1xf32>, tensor<8x1x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x32xf32>
388+
%collapsed_7 = tensor.collapse_shape %21 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32>
389+
return %collapsed_7, %collapsed_4 : tensor<256xf32>, tensor<8xf32>
390+
}

0 commit comments

Comments (0)