@@ -213,8 +213,8 @@ AttentionOperands<GEMMOperandPrecision> create_fp16_backward_precisions()
213213 memory_precisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP16;
214214 memory_precisions[AttentionOperand::dK] = GEMMOperandPrecision::FP16;
215215 memory_precisions[AttentionOperand::dV] = GEMMOperandPrecision::FP16;
216- memory_precisions[AttentionOperand::L] = GEMMOperandPrecision::FP32 ;
217- memory_precisions[AttentionOperand::D] = GEMMOperandPrecision::FP32 ;
216+ memory_precisions[AttentionOperand::L] = GEMMOperandPrecision::FP16 ;
217+ memory_precisions[AttentionOperand::D] = GEMMOperandPrecision::BF16 ;
218218 return memory_precisions;
219219}
220220
@@ -393,7 +393,7 @@ ForwardPipeline create_forward_pipeline(MTL::Device* device, const AttentionCase
393393 bundle.descriptor .Hk = attention.Hk ;
394394 bundle.descriptor .lowPrecisionInputs = true ;
395395 bundle.descriptor .isBF16 = false ;
396- bundle.descriptor .lowPrecisionIntermediates = false ;
396+ bundle.descriptor .lowPrecisionIntermediates = true ;
397397 bundle.descriptor .matrixDimensions = simd::uint3 { attention.R , attention.C , attention.D };
398398 bundle.descriptor .type = AttentionKernelType::forward;
399399 bundle.descriptor .scale = create_scale (attention);
@@ -431,7 +431,7 @@ BackwardPipelines create_backward_pipelines(
431431 bundle.query_descriptor .Hk = attention.Hk ;
432432 bundle.query_descriptor .lowPrecisionInputs = true ;
433433 bundle.query_descriptor .isBF16 = false ;
434- bundle.query_descriptor .lowPrecisionIntermediates = false ;
434+ bundle.query_descriptor .lowPrecisionIntermediates = true ;
435435 bundle.query_descriptor .matrixDimensions = simd::uint3 { attention.R , attention.C , attention.D };
436436 bundle.query_descriptor .type = AttentionKernelType::backwardQuery;
437437 bundle.query_descriptor .scale = create_scale (attention);
@@ -941,7 +941,7 @@ int main(int argc, char** argv)
941941 << " blockC=" << forward_pipeline.kernel ->blockDimensions [1 ]
942942 << " blockD=" << forward_pipeline.kernel ->blockDimensions [2 ]
943943 << " simdgroups=" << forward_pipeline.kernel ->executionSIMDGroups
944- << " lowPrecisionIntermediates=false "
944+ << " lowPrecisionIntermediates=true "
945945 << ' \n ' ;
946946 std::cout << " backward-kernel"
947947 << " queryBlockR=" << backward_pipelines.query_kernel ->blockDimensions [0 ]
0 commit comments