
Commit 4a757de

Fix attention bugs (swap thread and iter when Q LDS is bypassed and bf16 tests) (#1797)
* Fix some attention bugs:
  - do not swap thread and iter subdims for Q if we are bypassing LDS
  - use f32 accumulation for attention in the CPU code
  - fix a bug in maskKVCacheTosa for bf16
1 parent 2f4cb84 commit 4a757de

4 files changed, +72 -44 lines changed


mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 6 additions & 0 deletions
@@ -868,6 +868,9 @@ struct GridwiseAttentionAccelRewritePattern
         return failure();
       }
     } else {
+      assert(!ldsLayoutCfg.doSwapThreadIterSubDims &&
+             "doSwapThreadIterSubDims must be false if the destination buffer "
+             "is private memory");
       accel::AccelEmitterParams accelEmitterParams = accelEmitter.getParams();
       int64_t dRepeats = (nonKDimName == "m" ? accelEmitterParams.mRepeats
                                              : accelEmitterParams.nRepeats);
@@ -1754,6 +1757,9 @@ struct GridwiseAttentionAccelRewritePattern
     }
     LDSLayoutConfigDim ldsLayoutCfgNG0 = getLDSLayoutConfigDim(
         elemTypeQ, gemm0kpack, maybeVectorDimInfoQ.value());
+    if (doBypassLDSForQ) {
+      ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
+    }
     FailureOr<VectorDimInfo> maybeVectorDimInfoK =
         getVectorDim(rewriter, loc, inK, elemTypeK, blockSize, gemm0KPerBlock,
                      gemm0MPerBlock, gemm0kpack);

mlir/test/e2e/PrAttentionBF16.toml

Lines changed: 12 additions & 12 deletions
@@ -51,33 +51,33 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -perf_con
 config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -perf_config attn:v1:32,32,64,8,16,16,8,1"

 # check scale
-#[[suite.test]]
-#config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale"
+[[suite.test]]
+config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale"

 # check bias
 [[suite.test]]
 config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-bias"

 # check scale and bias together
-#[[suite.test]]
-#config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
+[[suite.test]]
+config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

 # cross attention
 [[suite.test]]
 config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias"

 # issue 1661
-#[[suite.test]]
-#config = "-seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
+[[suite.test]]
+config = "-seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

 # GQA
-#[[suite.test]]
-#config = "-num_heads_q 4 -num_heads_kv 2 -seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
+[[suite.test]]
+config = "-num_heads_q 4 -num_heads_kv 2 -seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

 # GQA + KV Cache
-#[[suite.test]]
-#config = "-rand 1 -current_seq_len=17 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
+[[suite.test]]
+config = "-rand 1 -current_seq_len=17 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

 # GQA + KV Cache batch=3
-#[[suite.test]]
-#config = "-rand 1 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
+[[suite.test]]
+config = "-rand 1 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+// RUN: rocmlir-gen --arch %arch --operation attention -t f16 -seq_len_q 8 -seq_len_k 8 -head_dim_qk 8 -head_dim_v 8 -perf_config attn:v1:32,32,64,8,16,16,8,1 --transQ=true --transK=true --transV=false --transO=false -rand 1 -rand_type int -pv -relDiff_threshold 0.02 -RMS_threshold 0.015 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s
+// RUN: rocmlir-gen --arch %arch --operation attention -t f16 -seq_len_q 8 -seq_len_k 8 -head_dim_qk 8 -head_dim_v 8 -perf_config attn:v1:32,32,64,8,32,32,8,1 --transQ=true --transK=true --transV=false --transO=false -rand 1 -rand_type int -pv -relDiff_threshold 0.02 -RMS_threshold 0.015 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK: [1 1 1]

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 50 additions & 32 deletions
@@ -2317,9 +2317,9 @@ getAttentionDimNames(SmallVectorImpl<SmallVector<StringRef>> &result,
   else
     result.emplace_back(SmallVector<StringRef>{gName, seqQName, headQKName});
   if (transposeK)
-    result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
-  else
     result.emplace_back(SmallVector<StringRef>{gName, seqKName, headQKName});
+  else
+    result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
   if (transposeV)
     result.emplace_back(SmallVector<StringRef>{gName, headVName, seqKName});
   else
@@ -2369,9 +2369,8 @@ Value addTensorArgToBlock(OpBuilder &builder, Location loc,
   return funcArgTensor;
 }

-template <typename T>
 static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
-                             Value currentSeqLenVal, T initValue) {
+                             Value currentSeqLenVal, float initValue) {
   // inputTensor is [B*NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV], we want to reshape to
   // [B, NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV]
   auto origType = cast<RankedTensorType>(inputTensor.getType());
@@ -2423,28 +2422,16 @@ static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
                              currentSeqLenBroadcast);

   // create a tensor with a single value and broadcast it
-  DenseElementsAttr initValueAttr;
-  if constexpr (std::is_same_v<T, int32_t>) {
-    assert(inpType.getElementType() == builder.getI32Type());
-    initValueAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get(inpShape, inpType.getElementType()), initValue);
-  } else if constexpr (std::is_same_v<T, float>) {
-    assert(inpType.getElementType() == builder.getF32Type() ||
-           inpType.getElementType() == builder.getF16Type());
-    llvm::APFloat fpVal(initValue);
-    if (inpType.getElementType() == builder.getF16Type()) {
-      bool losesInfo = false;
-      auto status =
-          fpVal.convert(llvm::APFloat::IEEEhalf(),
-                        llvm::APFloat::rmNearestTiesToEven, &losesInfo);
-      assert(status == llvm::APFloat::opOK);
-    }
-    initValueAttr = DenseFPElementsAttr::get(
-        RankedTensorType::get(inpShape, inpType.getElementType()), fpVal);
-  } else {
-    static_assert(!std::is_same_v<T, T>,
-                  "Unsupported type for MLIR type mapping");
-  }
+  assert(isa<FloatType>(inpType.getElementType()));
+  std::pair<APFloat, llvm::detail::opStatus> floatRes =
+      rock::createAPFloat(inpType.getElementType(), initValue);
+  APFloat fpVal = floatRes.first;
+  auto status = floatRes.second;
+  assert(status == APFloat::opOK);
+
+  DenseElementsAttr initValueAttr = DenseFPElementsAttr::get(
+      RankedTensorType::get(inpShape, inpType.getElementType()), fpVal);
+
   Value initVal = builder.create<tosa::ConstOp>(loc, initValueAttr.getType(),
                                                 initValueAttr);

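This is the bf16 part of the fix: the removed path only knew how to narrow the fill value to IEEE half and asserted that the element type was f32 or f16, so a bf16 mask value never got a correct conversion. The new code delegates to rock::createAPFloat, which lives elsewhere in the repo; as a rough illustration of the conversion such a helper has to perform (the name, signature, and body below are ours, not the actual implementation), converting into the element type's own float semantics covers f16, bf16, and f32 uniformly:

#include <utility>
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"

// Illustrative sketch only: convert a host float into the APFloat semantics of
// an MLIR float element type (f16, bf16, f32, ...). rock::createAPFloat is the
// repo's real helper; this just shows the APFloat conversion involved.
static std::pair<llvm::APFloat, llvm::APFloat::opStatus>
convertToElementSemantics(mlir::FloatType elemType, float initValue) {
  llvm::APFloat fpVal(initValue); // starts out in IEEE single precision
  bool losesInfo = false;
  llvm::APFloat::opStatus status =
      fpVal.convert(elemType.getFloatSemantics(),
                    llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return {fpVal, status};
}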
@@ -2809,6 +2796,18 @@ static Value transposeMatrix(OpBuilder &builder, Location loc, Value src,
   return createOpAndInfer<tosa::TransposeOp>(builder, loc, elemType, src, perm);
 }

+static Type getAccType(Type inputType, OpBuilder builder) {
+  Type accType;
+  if (isa<FloatType>(inputType)) {
+    accType = builder.getF32Type();
+  } else if (isa<IntegerType>(inputType)) {
+    accType = builder.getI32Type();
+  } else {
+    llvm_unreachable("not expected type");
+  }
+  return accType;
+}
+
 static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
                                                      const GenParams &params) {
   MLIRContext *ctx = module.getContext();
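A quick usage note on the new helper: it maps any float element type to an f32 accumulator and any integer element type to i32. The calls below are illustrative only, not part of the diff:

// Illustrative only: pick the 32-bit accumulator type for a given element type.
Type accForBF16 = getAccType(builder.getBF16Type(), builder);     // -> f32
Type accForInt8 = getAccType(builder.getIntegerType(8), builder); // -> i32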
@@ -2880,9 +2879,18 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   auto keysZp =
       tosa::createZeroPointTensor(builder, loc, keysTensor.getType(), 0)
           .value();
-  Value qkTensor = createOpAndInfer<tosa::MatMulOp>(
-      builder, loc, firstGemmOutElemType, queriesTensor, keysTensor, queriesZp,
-      keysZp);
+  // TODO: if/when tosa::matmul has acc_type implemented, we can use it here to
+  // be more similar to what the gpu code does
+  // accumulate in 32 bit
+  Type firstAccType = getAccType(firstGemmOutElemType, builder);
+  assert(firstAccType == getAccType(params.types[1], builder));
+  Value qkTensorBeforeConversion = createOpAndInfer<tosa::MatMulOp>(
+      builder, loc, firstAccType, queriesTensor, keysTensor, queriesZp, keysZp);
+  Value qkTensor = builder.createOrFold<tosa::CastOp>(
+      loc,
+      cast<ShapedType>(qkTensorBeforeConversion.getType())
+          .clone(firstGemmOutElemType),
+      qkTensorBeforeConversion);

   // get currentSeqLenTensor
   Value currentSeqLenTensor;
@@ -2995,9 +3003,19 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   auto valuesZp =
       tosa::createZeroPointTensor(builder, loc, valuesTensor.getType(), 0)
          .value();
-  Value resultTensor = createOpAndInfer<tosa::MatMulOp>(
-      builder, loc, resultOutElementType, softmaxTensor, valuesTensor,
-      softmaxZp, valuesZp);
+
+  // TODO: if/when tosa::matmul has acc_type implemented, we can use it here to
+  // be more similar to what the gpu code does
+  // accumulate in 32 bit
+  Type secondAccType = getAccType(resultOutElementType, builder);
+  Value resultTensorBeforeConversion = createOpAndInfer<tosa::MatMulOp>(
+      builder, loc, secondAccType, softmaxTensor, valuesTensor, softmaxZp,
+      valuesZp);
+  Value resultTensor = builder.createOrFold<tosa::CastOp>(
+      loc,
+      cast<ShapedType>(resultTensorBeforeConversion.getType())
+          .clone(resultOutElementType),
+      resultTensorBeforeConversion);

   if (transposeO) {
     resultTensor = transposeMatrix(builder, loc, resultTensor, {0, 2, 1});
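The motivation for the matmul-then-cast change above (the "use f32 accumulation for attention in the CPU code" bullet of the commit message) is that the GPU kernels accumulate these gemms in 32 bits, so a CPU reference that accumulates directly in the narrow element type drifts away from them, most visibly for bf16 with its 8-bit mantissa. A small standalone illustration of that drift follows; it is plain C++, not the repo's code, and bf16 rounding is emulated by truncating a float's low mantissa bits:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to bf16 precision: keep the top 16 bits of the IEEE-754
// encoding, rounding to nearest (ties to even) on the dropped bits.
static float roundToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // A 64-element dot product, i.e. the reduction size of a head_dim_qk=64 gemm.
  float accF32 = 0.0f;  // 32-bit accumulation, as the GPU kernels do
  float accBF16 = 0.0f; // accumulation rounded back to bf16 after every step
  for (int i = 0; i < 64; ++i) {
    float a = roundToBF16(0.11f * (i % 7 + 1));
    float b = roundToBF16(0.13f * (i % 5 + 1));
    accF32 += a * b;
    accBF16 = roundToBF16(accBF16 + a * b);
  }
  std::printf("f32 acc = %f, bf16 acc = %f\n", accF32, accBF16);
  return 0;
}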
