@@ -2317,9 +2317,9 @@ getAttentionDimNames(SmallVectorImpl<SmallVector<StringRef>> &result,
   else
     result.emplace_back(SmallVector<StringRef>{gName, seqQName, headQKName});
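+  // With transposeK, the K dimensions are [g, seqK, headQK] rather than
+  // [g, headQK, seqK].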
   if (transposeK)
-    result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
-  else
     result.emplace_back(SmallVector<StringRef>{gName, seqKName, headQKName});
+  else
+    result.emplace_back(SmallVector<StringRef>{gName, headQKName, seqKName});
   if (transposeV)
     result.emplace_back(SmallVector<StringRef>{gName, headVName, seqKName});
   else
@@ -2369,9 +2369,8 @@ Value addTensorArgToBlock(OpBuilder &builder, Location loc,
   return funcArgTensor;
 }
 
-template <typename T>
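+// Mask score entries beyond currentSeqLenVal, filling them with initValue.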
 static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
-                             Value currentSeqLenVal, T initValue) {
+                             Value currentSeqLenVal, float initValue) {
   // inputTensor is [B*NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV], we want to reshape to
   // [B, NUM_HEADS, SEQ_LEN_Q, SEQ_LEN_KV]
   auto origType = cast<RankedTensorType>(inputTensor.getType());
@@ -2423,28 +2422,16 @@ static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
                              currentSeqLenBroadcast);
 
   // create a tensor with a single value and broadcast it
-  DenseElementsAttr initValueAttr;
-  if constexpr (std::is_same_v<T, int32_t>) {
-    assert(inpType.getElementType() == builder.getI32Type());
-    initValueAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get(inpShape, inpType.getElementType()), initValue);
-  } else if constexpr (std::is_same_v<T, float>) {
-    assert(inpType.getElementType() == builder.getF32Type() ||
-           inpType.getElementType() == builder.getF16Type());
-    llvm::APFloat fpVal(initValue);
-    if (inpType.getElementType() == builder.getF16Type()) {
-      bool losesInfo = false;
-      auto status =
-          fpVal.convert(llvm::APFloat::IEEEhalf(),
-                        llvm::APFloat::rmNearestTiesToEven, &losesInfo);
-      assert(status == llvm::APFloat::opOK);
-    }
-    initValueAttr = DenseFPElementsAttr::get(
-        RankedTensorType::get(inpShape, inpType.getElementType()), fpVal);
-  } else {
-    static_assert(!std::is_same_v<T, T>,
-                  "Unsupported type for MLIR type mapping");
-  }
+  assert(isa<FloatType>(inpType.getElementType()));
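+  // Convert the f32 init value to the tensor's float element type and check
+  // that the APFloat conversion succeeded.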
+  std::pair<APFloat, llvm::detail::opStatus> floatRes =
+      rock::createAPFloat(inpType.getElementType(), initValue);
+  APFloat fpVal = floatRes.first;
+  auto status = floatRes.second;
+  assert(status == APFloat::opOK);
+
+  DenseElementsAttr initValueAttr = DenseFPElementsAttr::get(
+      RankedTensorType::get(inpShape, inpType.getElementType()), fpVal);
+
   Value initVal = builder.create<tosa::ConstOp>(loc, initValueAttr.getType(),
                                                 initValueAttr);
 
@@ -2809,6 +2796,18 @@ static Value transposeMatrix(OpBuilder &builder, Location loc, Value src,
   return createOpAndInfer<tosa::TransposeOp>(builder, loc, elemType, src, perm);
 }
 
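+// Map an input element type to the 32-bit accumulator type used by the CPU
+// reference kernel: f32 for float inputs, i32 for integer inputs.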
+static Type getAccType(Type inputType, OpBuilder builder) {
+  Type accType;
+  if (isa<FloatType>(inputType)) {
+    accType = builder.getF32Type();
+  } else if (isa<IntegerType>(inputType)) {
+    accType = builder.getI32Type();
+  } else {
+    llvm_unreachable("unexpected type");
+  }
+  return accType;
+}
+
 static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
                                                      const GenParams &params) {
   MLIRContext *ctx = module.getContext();
@@ -2880,9 +2879,18 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   auto keysZp =
       tosa::createZeroPointTensor(builder, loc, keysTensor.getType(), 0)
           .value();
-  Value qkTensor = createOpAndInfer<tosa::MatMulOp>(
-      builder, loc, firstGemmOutElemType, queriesTensor, keysTensor, queriesZp,
-      keysZp);
+  // TODO: if/when tosa::matmul has acc_type implemented, we can use it here
+  // to be more similar to what the GPU code does.
+  // Accumulate in 32 bits.
+  Type firstAccType = getAccType(firstGemmOutElemType, builder);
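+  // Both element types must map to the same 32-bit accumulator type.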
+  assert(firstAccType == getAccType(params.types[1], builder));
+  Value qkTensorBeforeConversion = createOpAndInfer<tosa::MatMulOp>(
+      builder, loc, firstAccType, queriesTensor, keysTensor, queriesZp, keysZp);
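+  // Cast the 32-bit accumulator result back to the first GEMM's output
+  // element type.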
+  Value qkTensor = builder.createOrFold<tosa::CastOp>(
+      loc,
+      cast<ShapedType>(qkTensorBeforeConversion.getType())
+          .clone(firstGemmOutElemType),
+      qkTensorBeforeConversion);
 
   // get currentSeqLenTensor
   Value currentSeqLenTensor;
@@ -2995,9 +3003,19 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
   auto valuesZp =
       tosa::createZeroPointTensor(builder, loc, valuesTensor.getType(), 0)
           .value();
-  Value resultTensor = createOpAndInfer<tosa::MatMulOp>(
-      builder, loc, resultOutElementType, softmaxTensor, valuesTensor,
-      softmaxZp, valuesZp);
+
+  // TODO: if/when tosa::matmul has acc_type implemented, we can use it here
+  // to be more similar to what the GPU code does.
+  // Accumulate in 32 bits.
+  Type secondAccType = getAccType(resultOutElementType, builder);
+  Value resultTensorBeforeConversion = createOpAndInfer<tosa::MatMulOp>(
+      builder, loc, secondAccType, softmaxTensor, valuesTensor, softmaxZp,
+      valuesZp);
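+  // Cast the 32-bit accumulator result back to the kernel's output element
+  // type.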
+  Value resultTensor = builder.createOrFold<tosa::CastOp>(
+      loc,
+      cast<ShapedType>(resultTensorBeforeConversion.getType())
+          .clone(resultOutElementType),
+      resultTensorBeforeConversion);
 
   if (transposeO) {
     resultTensor = transposeMatrix(builder, loc, resultTensor, {0, 2, 1});