
Commit 078a100

Merge pull request #1770 from ROCm/1771-kv-cache-integration-make-sure-the-user-is-broadcasting-the-right-dimensions-for-currentseqlen
Support for causal attention and more strict checks for KV-Cache
2 parents b16004d + facef65

File tree

4 files changed: +402 -8 lines changed

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 2 additions & 2 deletions
@@ -211,7 +211,7 @@ def Rock_AttentionOp :
     TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
     TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
     TensorOrMemRefOf<[F32, F16, BF16]>:$values,
-    Variadic<TensorOrMemRefOf<[F32, F16, BF16, I8]>>:$preSoftmaxElemWiseInputs,
+    Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
     Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
     TensorOrMemRefOf<[F32, BF16, F16]>:$out,
     UnitAttr:$qTransposed,
@@ -437,7 +437,7 @@ def Rock_GridwiseAttentionAccelOp :
   Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
     MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
     MemRefRankOf<[F32, F16, BF16,], [3]>:$values,
-    Variadic<TensorOrMemRefOf<[F32, F16, BF16, I8]>>:$preSoftmaxElemWiseInputs,
+    Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
     Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
     MemRefRankOf<[F32, F16, BF16], [3]>:$out,
     StrAttr:$arch,
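
Note: the diff does not say why the $preSoftmaxElemWiseInputs constraint was widened to AnyTensorOrMemRef; a plausible reading, given the causal-attention support in the commit title, is that mask tensors with element types outside the old F32/F16/BF16/I8 list (e.g. booleans) can now be fed in before the softmax. Treat that as an assumption. For reference, a standalone C++ sketch of the causal mask such an input typically encodes:

// Standalone sketch (not ROCm code): the lower-triangular causal mask that a
// pre-softmax elementwise input typically encodes. Masked score entries are
// driven to -inf so softmax gives them zero weight. seqLen and the dummy
// scores are illustrative.
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const int seqLen = 4;
  // scores[i][j]: raw query-i / key-j dot product (dummy value 1.0).
  std::vector<std::vector<float>> scores(seqLen,
                                         std::vector<float>(seqLen, 1.0f));

  // Causal rule: position i may only attend to positions j <= i.
  for (int i = 0; i < seqLen; ++i)
    for (int j = i + 1; j < seqLen; ++j)
      scores[i][j] = -std::numeric_limits<float>::infinity();

  for (int i = 0; i < seqLen; ++i) {
    for (int j = 0; j < seqLen; ++j)
      std::printf("%6.1f ", scores[i][j]);
    std::printf("\n");
  }
  return 0;
}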

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 47 additions & 6 deletions
@@ -905,6 +905,23 @@ static bool isElementwiseOp(Operation *op) {
 struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
   using OpRewritePattern<tosa::MatMulOp>::OpRewritePattern;

+  FailureOr<Value> getValueNonReshapeOpNonBroadcast(Value val) const {
+    while (val.getDefiningOp() &&
+           (val.getDefiningOp<tensor::CollapseShapeOp>() ||
+            val.getDefiningOp<tensor::ExpandShapeOp>() ||
+            val.getDefiningOp<tosa::TransposeOp>() ||
+            val.getDefiningOp<tosa::AddOp>())) {
+      if (val.getDefiningOp<tosa::AddOp>()) {
+        auto maybeBroadcast = addBroadcast(val);
+        if (failed(maybeBroadcast))
+          return failure();
+        val = maybeBroadcast.value();
+      } else
+        val = val.getDefiningOp()->getOperand(0);
+    }
+    return val;
+  }
+
   Value getValueNonReshapeOp(Value val) const {
     while (val.getDefiningOp() &&
            (val.getDefiningOp<tensor::CollapseShapeOp>() ||
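
The new getValueNonReshapeOpNonBroadcast helper walks a value's def-chain, peeling tensor.collapse_shape / tensor.expand_shape / tosa.transpose through their input and tosa.add through the broadcast operand that addBroadcast resolves, failing when the broadcast cannot be resolved. A minimal standalone model of the walk (toy types, not MLIR code; failure is modelled as nullptr instead of FailureOr<Value>):

// Standalone model of the new def-chain walk (toy types, not MLIR code).
#include <cstdio>

enum class Kind { BlockArg, Reshape, Transpose, BroadcastAdd, Other };

struct Node {
  Kind kind;
  Node *operand; // producer of this value; null for block arguments
};

Node *peelViewsAndBroadcasts(Node *v) {
  while (v != nullptr && v->kind != Kind::BlockArg) {
    switch (v->kind) {
    case Kind::Reshape:
    case Kind::Transpose:
      v = v->operand; // view-like ops: follow operand 0
      break;
    case Kind::BroadcastAdd:
      v = v->operand; // the real code re-derives the broadcast operand here
      break;
    default:
      return nullptr; // unexpected producer: the match is rejected
    }
  }
  return v;
}

int main() {
  Node arg{Kind::BlockArg, nullptr};
  Node reshape{Kind::Reshape, &arg};
  Node add{Kind::BroadcastAdd, &reshape};
  std::printf("reaches a block argument: %s\n",
              peelViewsAndBroadcasts(&add) == &arg ? "yes" : "no");
  return 0;
}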
@@ -1004,8 +1021,35 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
     if (failed(maybeNonZero2))
       return failure();

+    // check that the right dimensions are broadcasted
+    auto beforeBroadcastShape =
+        dyn_cast<ShapedType>(maybeNonZero2->getType());
+    if (beforeBroadcastShape) {
+      auto shape = beforeBroadcastShape.getShape();
+      if (beforeBroadcastShape.getRank() > 2 &&
+          !llvm::all_of(shape.slice(2), [](int32_t v) { return v == 1; }))
+        return failure();
+    } else {
+      return failure();
+    }
+
     Value currentSeqLen = getValueNonReshapeOp(maybeNonZero2.value());
     Value result = select.getOnFalse();
+
+    // currentSeqLen must be of i32 type
+    auto currentSeqLenShape = dyn_cast<ShapedType>(currentSeqLen.getType());
+    if (!currentSeqLenShape ||
+        !currentSeqLenShape.getElementType().isInteger(32))
+      return failure();
+
+    // we'll check now if currentSeqLen comes from a block argument
+    FailureOr<Value> mustBeBlockArg =
+        getValueNonReshapeOpNonBroadcast(currentSeqLen);
+
+    if (failed(mustBeBlockArg) ||
+        !isa<BlockArgument>(mustBeBlockArg.value()))
+      return failure();
+
     return std::make_pair(result, currentSeqLen);
   }
 }
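
These checks enforce what the source branch name asks for: the pre-broadcast shape of currentSeqLen may only have unit dimensions past the first two, its element type must be i32, and it must trace back to a block argument. A minimal standalone sketch of the dimension check, with std::all_of standing in for llvm::all_of and illustrative shapes:

// Standalone sketch of the broadcast-dimension check (not MLIR code):
// before broadcasting, every dimension of currentSeqLen's shape past the
// first two must be 1, mirroring llvm::all_of(shape.slice(2), ...) above.
#include <algorithm>
#include <cstdio>
#include <vector>

bool broadcastsOnlyAllowedDims(const std::vector<long> &shape) {
  if (shape.size() <= 2)
    return true; // rank <= 2: nothing past the leading dims to check
  return std::all_of(shape.begin() + 2, shape.end(),
                     [](long d) { return d == 1; });
}

int main() {
  std::printf("[8, 128]       -> %d\n", broadcastsOnlyAllowedDims({8, 128}));
  std::printf("[8, 128, 1, 1] -> %d\n",
              broadcastsOnlyAllowedDims({8, 128, 1, 1}));
  std::printf("[8, 128, 16]   -> %d\n",
              broadcastsOnlyAllowedDims({8, 128, 16}));
  return 0;
}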
@@ -1216,9 +1260,8 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
   LogicalResult match(tosa::MatMulOp op) const override {
     FailureOr<std::tuple<Value, bool, Value>> softmaxInputResult =
         maybeSoftmax(op.getA());
-    if (failed(softmaxInputResult)) {
+    if (failed(softmaxInputResult))
       return failure();
-    }

     Value softmaxInput, currentSeqLen;
     bool hasReduceOp;
@@ -1245,12 +1288,10 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
       LLVM_DEBUG(llvm::dbgs()
                  << "first matmul = " << maybeFirstMatMul.value() << "\n");
       LLVM_DEBUG(llvm::dbgs() << "hasReduceOp = " << hasReduceOp << "\n");
-      if (isDotProduct && hasReduceOp) {
+      if (isDotProduct && hasReduceOp)
        return failure();
-      }
-      if (!isDotProduct && !hasReduceOp) {
+      if (!isDotProduct && !hasReduceOp)
        return failure();
-      }
     } else {
       LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n");
     }
