### Description
DeepSeek-R1 in f16 produces incorrect results on the flash attention
path, while DeepSeek-R1 in f32 works correctly on it. (Both work
correctly on the non-flash-attention path.)
This PR fixes the incorrect DeepSeek-R1 results on the flash attention
path. The results of qk and softmax(qk) appear to overflow in f16.
After changing the qk and softmax(qk) computation to use f32 and
storing the qk result as type `float`, the DeepSeek-R1 model works
correctly.
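As a minimal illustration of the overflow (Python/numpy, not the actual shader; the sizes and values are hypothetical): a q·k dot product accumulated in f16 can exceed the float16 maximum (~65504) even when every individual product is small.

```python
import numpy as np

head_dim = 256
q = np.full(head_dim, 16.0, dtype=np.float16)
k = np.full(head_dim, 16.0, dtype=np.float16)

# Accumulate the dot product in float16, as the f16 shader path would:
# each term is 16 * 16 = 256, but 256 terms sum to 65536 > 65504 -> inf.
qk_f16 = np.float16(0)
for i in range(head_dim):
    qk_f16 += q[i] * k[i]

# The same dot product accumulated in float32 is exact.
qk_f32 = q.astype(np.float32) @ k.astype(np.float32)

print(np.isinf(qk_f16))  # True: the f16 accumulator overflowed
print(qk_f32)            # 65536.0
```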
This PR includes the following changes:
1. Add head_idx boundary checking. This is not directly related to the
DeepSeek fix, but it is needed because
[ProgramManager::NormalizeDispatchGroupSize](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/webgpu/program_manager.cc#L22-L39)
dispatches more workgroups than requested when the count exceeds
`maxComputeWorkgroupsPerDimension`.
2. Use f32 for the computation of qk and softmax(qk).
3. Store the results of qk and the max/sum as `float` instead of q's
original type.
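Changes 2 and 3 can be sketched as follows (illustrative Python/numpy, not the actual WGSL shader; the function name and shapes are made up): q and k remain f16, but qk and the softmax max/sum are computed and stored in f32.

```python
import numpy as np

def attention_scores_f32(q, k):
    """q, k: float16 arrays. qk and softmax(qk) are computed in float32."""
    # qk is accumulated and stored as float32 ("float"), not q's f16 type.
    qk = q.astype(np.float32) @ k.astype(np.float32).T
    # The max and sum used by softmax are also tracked in float32.
    qk_max = qk.max(axis=-1, keepdims=True)
    e = np.exp(qk - qk_max)
    return e / e.sum(axis=-1, keepdims=True)

# With head_dim = 512 and elements of 16.0, each qk entry is
# 16 * 16 * 512 = 131072 > 65504, so a pure-f16 qk would be inf
# and the softmax would produce NaN. The f32 path stays finite.
q = np.full((2, 512), 16.0, dtype=np.float16)
k = np.full((4, 512), 16.0, dtype=np.float16)
scores = attention_scores_f32(q, k)   # finite; each row sums to 1
```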