Commit 1655ca0

[6.4][Backport] Backport some attention bugfixes + causal attention (#1811)
Support for causal attention and stricter checks for KV-Cache.

* Fix attention bugs (swap thread and iter when Q LDS is bypassed, and bf16 tests) (#1797)
* Fix some attention bugs:
  - do not swap thread and iter subdims for Q if we are bypassing LDS
  - use f32 attention in CPU code
  - fix bug in maskKVCacheTosa for bf16

---------
Co-authored-by: Daniel Hernandez-Juarez <[email protected]>
1 parent be36966 commit 1655ca0
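For context on the KV-Cache checks and the f32 CPU fix in this backport, here is a rough CPU reference of what the attention op computes for one query row, out = SOFTMAX((queries * keys) .* scale) * values, where only the first currentSeqLen cached key/value positions take part in the softmax. This is an illustrative sketch only: attentionRowRef and its parameters are invented for this example and are not rocMLIR APIs, and the whole computation is kept in f32 in the spirit of the "use f32 attention in CPU code" fix.

// Illustrative CPU reference for one query row of
//   out = SOFTMAX((queries * keys) .* scale) * values
// with a KV-cache length limit. Hypothetical helper, not rocMLIR code.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// q: [d], k/v: [seqLen x d] row-major, out: [d]. Key/value positions at or
// beyond currentSeqLen are excluded from the softmax (KV-cache masking).
void attentionRowRef(const float *q, const float *k, const float *v,
                     float *out, int64_t seqLen, int64_t d, float scale,
                     int32_t currentSeqLen) {
  int64_t valid = std::min<int64_t>(seqLen, currentSeqLen);
  if (valid <= 0) { // nothing cached yet: define the output as zero
    std::fill(out, out + d, 0.0f);
    return;
  }
  std::vector<float> logits(valid);
  float maxLogit = -INFINITY;
  for (int64_t j = 0; j < valid; ++j) {
    float dot = 0.0f; // accumulate in f32 even if inputs were fp16/bf16
    for (int64_t x = 0; x < d; ++x)
      dot += q[x] * k[j * d + x];
    logits[j] = dot * scale;
    maxLogit = std::max(maxLogit, logits[j]);
  }
  float denom = 0.0f;
  for (int64_t j = 0; j < valid; ++j) {
    logits[j] = std::exp(logits[j] - maxLogit); // numerically stable softmax
    denom += logits[j];
  }
  for (int64_t x = 0; x < d; ++x) {
    float acc = 0.0f;
    for (int64_t j = 0; j < valid; ++j)
      acc += logits[j] * v[j * d + x];
    out[x] = acc / denom;
  }
}

The running-max subtraction is the usual numerically stable softmax; positions at or beyond currentSeqLen simply never enter the sums, which is the causal/KV-cache behaviour the currentSeqLen operand expresses in the diffs below.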

11 files changed (+543, -108 lines)

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 32 additions & 39 deletions
@@ -205,27 +205,22 @@ def Rock_ReduceOp :
   }];
 }
 
-def Rock_AttentionOp :
-    Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
-    Arguments<(ins
-      TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
-      TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
-      TensorOrMemRefOf<[F32, F16, BF16]>:$values,
-      Variadic<TensorOrMemRefOf<[F32, F16, BF16, I8]>>:$preSoftmaxElemWiseInputs,
-      Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
-      TensorOrMemRefOf<[F32, BF16, F16]>:$out,
-      UnitAttr:$qTransposed,
-      UnitAttr:$kTransposed,
-      UnitAttr:$vTransposed,
-      UnitAttr:$oTransposed,
-      StrAttr:$arch,
-      Rock_GemmFeaturesAttr:$features,
-      OptionalAttr<I32Attr>:$numCU,
-      OptionalAttr<RockTuningParamAttrInterface>:$params0,
-      OptionalAttr<RockTuningParamAttrInterface>:$params1,
-      I32Attr:$firstGemmIdx
-    )>,
-    Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
+def Rock_AttentionOp
+    : Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                            RockFusionRoot, AttrSizedOperandSegments]>,
+      Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
+          TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
+          TensorOrMemRefOf<[F32, F16, BF16]>:$values,
+          Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
+          Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
+          TensorOrMemRefOf<[F32, BF16, F16]>:$out, UnitAttr:$qTransposed,
+          UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed,
+          StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
+          OptionalAttr<I32Attr>:$numCU,
+          OptionalAttr<RockTuningParamAttrInterface>:$params0,
+          OptionalAttr<RockTuningParamAttrInterface>:$params1,
+          I32Attr:$firstGemmIdx)>,
+      Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
   let summary = "Attention operation of transformer models";
   let description = [{
     Performs the operation out = SOFTMAX((queries * keys) .* scale) * values.
@@ -432,24 +427,22 @@ def Rock_GridwiseGemmAccelOp :
 }
 
 // gridwise_attention_accel
-def Rock_GridwiseAttentionAccelOp :
-    Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
-    Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
-      MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
-      MemRefRankOf<[F32, F16, BF16,], [3]>:$values,
-      Variadic<TensorOrMemRefOf<[F32, F16, BF16, I8]>>:$preSoftmaxElemWiseInputs,
-      Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
-      MemRefRankOf<[F32, F16, BF16], [3]>:$out,
-      StrAttr:$arch,
-      Rock_GemmFeaturesAttr:$features,
-      I32Attr:$blockSize,
-      I32Attr:$gridSize,
-      UnitAttr:$disableQBypassLDS,
-      OptionalAttr<IndexAttr>:$prePadG0M,
-      OptionalAttr<IndexAttr>:$prePadG0N,
-      RockAccelTuningParamAttrInterface:$params0,
-      RockAccelTuningParamAttrInterface:$params1,
-      I32Attr:$firstGemmIdx)> {
+def Rock_GridwiseAttentionAccelOp
+    : Rock_Op<"gridwise_attention_accel",
+              [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+               RockFusionRoot, AttrSizedOperandSegments]>,
+      Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
+          MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
+          MemRefRankOf<[F32, F16, BF16, ], [3]>:$values,
+          Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
+          Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
+          MemRefRankOf<[F32, F16, BF16], [3]>:$out, StrAttr:$arch,
+          Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize,
+          I32Attr:$gridSize, UnitAttr:$disableQBypassLDS,
+          OptionalAttr<IndexAttr>:$prePadG0M,
+          OptionalAttr<IndexAttr>:$prePadG0N,
+          RockAccelTuningParamAttrInterface:$params0,
+          RockAccelTuningParamAttrInterface:$params1, I32Attr:$firstGemmIdx)> {
   let summary = "Gridwise attention accelerated version";
   let description = [{
     The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 47 additions & 6 deletions
@@ -905,6 +905,23 @@ static bool isElementwiseOp(Operation *op) {
 struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
   using OpRewritePattern<tosa::MatMulOp>::OpRewritePattern;
 
+  FailureOr<Value> getValueNonReshapeOpNonBroadcast(Value val) const {
+    while (val.getDefiningOp() &&
+           (val.getDefiningOp<tensor::CollapseShapeOp>() ||
+            val.getDefiningOp<tensor::ExpandShapeOp>() ||
+            val.getDefiningOp<tosa::TransposeOp>() ||
+            val.getDefiningOp<tosa::AddOp>())) {
+      if (val.getDefiningOp<tosa::AddOp>()) {
+        auto maybeBroadcast = addBroadcast(val);
+        if (failed(maybeBroadcast))
+          return failure();
+        val = maybeBroadcast.value();
+      } else
+        val = val.getDefiningOp()->getOperand(0);
+    }
+    return val;
+  }
+
   Value getValueNonReshapeOp(Value val) const {
     while (val.getDefiningOp() &&
            (val.getDefiningOp<tensor::CollapseShapeOp>() ||
@@ -1004,8 +1021,35 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
       if (failed(maybeNonZero2))
         return failure();
 
+      // check that the right dimensions are broadcasted
+      auto beforeBroadcastShape =
+          dyn_cast<ShapedType>(maybeNonZero2->getType());
+      if (beforeBroadcastShape) {
+        auto shape = beforeBroadcastShape.getShape();
+        if (beforeBroadcastShape.getRank() > 2 &&
+            !llvm::all_of(shape.slice(2), [](int32_t v) { return v == 1; }))
+          return failure();
+      } else {
+        return failure();
+      }
+
       Value currentSeqLen = getValueNonReshapeOp(maybeNonZero2.value());
       Value result = select.getOnFalse();
+
+      // currentSeqLen must be of i32 type
+      auto currentSeqLenShape = dyn_cast<ShapedType>(currentSeqLen.getType());
+      if (!currentSeqLenShape ||
+          !currentSeqLenShape.getElementType().isInteger(32))
+        return failure();
+
+      // we'll check now if currentSeqLen comes from a block argument
+      FailureOr<Value> mustBeBlockArg =
+          getValueNonReshapeOpNonBroadcast(currentSeqLen);
+
+      if (failed(mustBeBlockArg) ||
+          !isa<BlockArgument>(mustBeBlockArg.value()))
+        return failure();
+
       return std::make_pair(result, currentSeqLen);
     }
   }
@@ -1216,9 +1260,8 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
   LogicalResult match(tosa::MatMulOp op) const override {
     FailureOr<std::tuple<Value, bool, Value>> softmaxInputResult =
         maybeSoftmax(op.getA());
-    if (failed(softmaxInputResult)) {
+    if (failed(softmaxInputResult))
       return failure();
-    }
 
     Value softmaxInput, currentSeqLen;
     bool hasReduceOp;
@@ -1245,12 +1288,10 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
       LLVM_DEBUG(llvm::dbgs()
                  << "first matmul = " << maybeFirstMatMul.value() << "\n");
       LLVM_DEBUG(llvm::dbgs() << "hasReduceOp = " << hasReduceOp << "\n");
-      if (isDotProduct && hasReduceOp) {
+      if (isDotProduct && hasReduceOp)
        return failure();
-      }
-      if (!isDotProduct && !hasReduceOp) {
+      if (!isDotProduct && !hasReduceOp)
        return failure();
-      }
     } else {
       LLVM_DEBUG(llvm::dbgs() << "first matmul not found\n");
     }

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 16 additions & 2 deletions
@@ -868,6 +868,9 @@ struct GridwiseAttentionAccelRewritePattern
         return failure();
       }
     } else {
+      assert(!ldsLayoutCfg.doSwapThreadIterSubDims &&
+             "doSwapThreadIterSubDims must be false if the destination buffer "
+             "is private memory");
       accel::AccelEmitterParams accelEmitterParams = accelEmitter.getParams();
       int64_t dRepeats = (nonKDimName == "m" ? accelEmitterParams.mRepeats
                                              : accelEmitterParams.nRepeats);
@@ -1754,6 +1757,16 @@ struct GridwiseAttentionAccelRewritePattern
     }
     LDSLayoutConfigDim ldsLayoutCfgNG0 = getLDSLayoutConfigDim(
         elemTypeQ, gemm0kpack, maybeVectorDimInfoQ.value());
+    if (doBypassLDSForQ) {
+      ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
+    }
+#ifndef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
+    // TODO: Workaround for issue
+    // https://github.com/ROCm/rocMLIR-internal/issues/1802 If sumRowBuffer and
+    // expMaxDiffRowBuffer are filled with doSwapThreadIterSubDims=true, it does
+    // not match with the second GEMM N dimension. Find a good solution to this.
+    ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
+#endif
     FailureOr<VectorDimInfo> maybeVectorDimInfoK =
         getVectorDim(rewriter, loc, inK, elemTypeK, blockSize, gemm0KPerBlock,
                      gemm0MPerBlock, gemm0kpack);
@@ -1828,7 +1841,7 @@ struct GridwiseAttentionAccelRewritePattern
     Value accRegBufferGemm0 =
         createBufferForAccelGemmOut(loc, accelParamsGemm0, rewriter);
     // Currently, there is a working assumption that this kernel is meant
-    // support fp32/fp16 This should be guranteed by op verifiers.
+    // support fp32/fp16/bf16. This should be guranteed by op verifiers.
     Type gemmOutElemType = elemTypeQxK;
     Type softmaxInElemType = elemTypeQxK;
     if (elemTypeQ == rewriter.getI8Type()) {
@@ -1979,12 +1992,13 @@ struct GridwiseAttentionAccelRewritePattern
      if (failed(statusLoadQ)) {
        return failure();
      }
+      rewriter.create<LDSBarrierOp>(loc);
 
      TypedValue<MemRefType> ldsTileBufferQ = viewBufferAs(
          rewriter, ldsByteBufferQ, vectorTypeOrSelf(elemTypeQ, gemm0kpack));
      loadGemmOperandsFromLDSToRegs(
          rewriter, loc, ldsTileBufferQ, preAccelRegBuffersQ, "n", blockSize,
-          gemm0InMPerThread, *accelEmitterPtrGemm0.get(),
+          gemm0InNPerThread, *accelEmitterPtrGemm0.get(),
          ldsLayoutCfgNG0.doRotateWithK);
      rewriter.create<GpuDeallocOp>(loc, ldsByteBufferQ);
    }
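A note on the LDSBarrierOp added right after the Q tile is stored to LDS: all workitems in the block must observe each other's stores before the tile is read back into registers for the accel GEMM. The sketch below is only a host-side analogy, not rocMLIR code; std::thread and std::barrier stand in for workitems and the LDS barrier, and names like ldsTile and numThreads are invented for the example. It shows the same store, synchronize, then load ordering.

// Host-side analogy for the LDS-barrier pattern: fill a shared tile,
// synchronize, then read data written by other threads. Requires C++20.
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int numThreads = 4;
  std::vector<int> ldsTile(numThreads, 0); // stands in for the LDS tile
  std::barrier sync(numThreads);           // stands in for the LDS barrier

  auto worker = [&](int tid) {
    ldsTile[tid] = tid * 10; // each "workitem" writes its slice of the tile
    sync.arrive_and_wait();  // without this, the reads below could see stale data
    int sum = 0;
    for (int v : ldsTile)    // read slices written by the other threads
      sum += v;
    std::printf("thread %d sees sum %d\n", tid, sum);
  };

  std::vector<std::thread> threads;
  for (int t = 0; t < numThreads; ++t)
    threads.emplace_back(worker, t);
  for (auto &t : threads)
    t.join();
  return 0;
}

Removing the arrive_and_wait call would let a thread read tile slices before the other threads have written them, which is the class of race the emitted barrier prevents on the GPU.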
