metal : FA support F32 K and V #16531
Conversation
@JohannesGaessler @jeffbolznv Would it be possible to add support for F32 K and V tensors in the respective backends? The issue is these casts on lines 1313 to 1326 in 4b2dae3. If we remove the casts, the memory usage should be significantly reduced for this use case. But to remove them, the FA implementation has to support F32 K and V.
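For context, a minimal sketch of the kind of graph-level cast being discussed, with made-up variable names (`ggml_cast`, `ggml_flash_attn_ext`, and `GGML_TYPE_F16` are the real ggml API; the rest is illustrative and not the exact code at the referenced lines):

```cpp
#include "ggml.h"

// Sketch only: with a cacheless context, K and V are produced as F32, so they
// are currently cast to F16 before flash attention, which costs an extra F16
// copy of K and V in the compute buffer.
static ggml_tensor * build_attn_sketch(
        ggml_context * ctx,
        ggml_tensor  * q, ggml_tensor * k, ggml_tensor * v,
        ggml_tensor  * kq_mask, float kq_scale) {
    ggml_tensor * k_f16 = ggml_cast(ctx, k, GGML_TYPE_F16); // extra compute-buffer memory
    ggml_tensor * v_f16 = ggml_cast(ctx, v, GGML_TYPE_F16); // extra compute-buffer memory

    return ggml_flash_attn_ext(ctx, q, k_f16, v_f16, kq_mask, kq_scale, 0.0f, 0.0f);

    // if the backend FA kernels accept F32 K and V directly, the casts (and
    // their buffers) can be dropped:
    //   return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
}
```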
It's definitely possible, but it will require additional considerations w.r.t. SRAM limits. For the tile kernel, what would need to be done is to determine FP16 vs. FP32 use via a template parameter rather than the […]
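A rough sketch of the template-parameter idea (illustrative only, not the actual fattn-tile kernel; kernel name, parameters, and tile sizes are made up):

```cuda
#include <cuda_fp16.h>

// The K/V element type becomes a template parameter instead of being fixed to
// half, so one kernel source covers both cases. The float instantiation
// doubles the shared-memory (SRAM) footprint of the K/V tiles, which is where
// the additional considerations come in.
template <typename KV_t, int D, int KQ_STRIDE> // KV_t = half today, float for this request
__global__ void fattn_tile_sketch(const KV_t * K, const KV_t * V, float * dst) {
    __shared__ KV_t K_tile[KQ_STRIDE][D]; // 2x the SRAM when KV_t == float
    __shared__ KV_t V_tile[KQ_STRIDE][D]; // 2x the SRAM when KV_t == float

    // ... tile loads, Q*K^T, softmax and the V accumulation (in float) go here ...
    (void) K; (void) V; (void) dst; (void) K_tile; (void) V_tile;
}

// one explicit instantiation per supported K/V type:
template __global__ void fattn_tile_sketch<half,  64, 32>(const half  *, const half  *, float *);
template __global__ void fattn_tile_sketch<float, 64, 32>(const float *, const float *, float *);
```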
Do the models for which this is relevant use GQA?
We could also consider making the operations preceding FA write back their data as FP16 in the first place. In terms of performance, that would definitely be preferable for all CUDA/ROCm GPUs except Pascal.
Generally yes. If it would make the implementation simpler, maybe we can treat F32 K and V as just another "quantization" type, where the dequantize function is a cast to F16?
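A hypothetical sketch of that approach (function names and signatures are made up; the real per-type loading hooks in the backends differ), where the F32 "dequantize" is literally a cast to F16:

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// Each K/V type provides a small conversion hook that the FA kernel calls
// while reading K/V. For F16 it is a plain load; for F32 it is a cast to F16,
// so F32 K/V could slot into the existing "quantized K/V" code path without a
// separate dequantization buffer.
static __device__ __forceinline__ void kv_load_f16(const void * vx, int64_t i, half2 & v) {
    v = ((const half2 *) vx)[i];
}

static __device__ __forceinline__ void kv_load_f32(const void * vx, int64_t i, half2 & v) {
    v = __float22half2_rn(((const float2 *) vx)[i]); // "dequantize" == cast to F16
}
```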
For CUDA that can definitely be done with comparatively little effort, but it would not eliminate the additional memory use; it would just shift it from the compute buffer to the buffer pool in the CUDA backend.
I think this should be relatively straightforward in the Vulkan backend; I'll look into it. This comment is how I'd expect to implement it (we dequantize while loading, so there is no extra memory usage).
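Purely as an illustration of the dequantize-while-loading idea (CUDA-style code with made-up names, not the actual Vulkan/GLSL shader): the conversion happens as the tile is staged into shared memory, so no casted copy of K/V is ever allocated in global memory.

```cuda
#include <cuda_fp16.h>

// K stays F32 in global memory; each element is converted to F16 on the fly
// while the tile is loaded into shared memory, so no extra buffer is needed
// for a casted copy of K/V.
template <int TILE>
__global__ void stage_k_tile_f32(const float * K, float * dst) {
    __shared__ half K_tile[TILE];

    const int i = threadIdx.x;
    if (i < TILE) {
        K_tile[i] = __float2half(K[(size_t) blockIdx.x * TILE + i]); // convert on load
    }
    __syncthreads();

    // ... the FA math would run on the F16 tile here; dummy write so the
    // sketch is self-contained and compiles ...
    if (i < TILE) {
        dst[(size_t) blockIdx.x * TILE + i] = __half2float(K_tile[i]);
    }
}

template __global__ void stage_k_tile_f32<128>(const float *, float *);
```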
Done for Vulkan in #16543
Basic CUDA support in #16546.
Force-pushed from c308925 to 5734546
Force-pushed from f027196 to 252bb89
* origin/master: (32 commits)
  metal : FA support F32 K and V and head size = 32 (ggml-org#16531)
  graph : support cacheless embeddings with FA and iSWA (ggml-org#16528)
  opencl: fix build targeting CL 2 (ggml-org#16554)
  CUDA: fix numerical issues in tile FA kernel (ggml-org#16540)
  ggml : fix build broken with -march=armv9-a on MacOS (ggml-org#16520)
  CANN: fix CPU memory leak in CANN backend (ggml-org#16549)
  fix: add remark plugin to render raw HTML as literal text (ggml-org#16505)
  metal: add support for opt_step_sgd (ggml-org#16539)
  ggml : fix scalar path for computing norm (ggml-org#16558)
  CANN: Update several operators to support FP16 data format (ggml-org#16251)
  metal : add opt_step_adamw and op_sum (ggml-org#16529)
  webui: remove client-side context pre-check and rely on backend for limits (ggml-org#16506)
  [SYCL] fix UT fault cases: count-equal, argsort, pad OPs (ggml-org#16521)
  ci : add Vulkan on Ubuntu with default packages build (ggml-org#16532)
  common : handle unicode during partial json parsing (ggml-org#16526)
  common : update presets (ggml-org#16504)
  ggml : Fix FP16 ELU positive branch (ggml-org#16519)
  hparams : add check for layer index in is_recurrent (ggml-org#16511)
  ggml: Correct SVE implementation in ggml_vec_dot_f16_unroll (ggml-org#16518)
  CUDA: faster tile FA, add oob checks, more HSs (ggml-org#16492)
  ...
target #16528
Remove K and V casts with cacheless contexts (we should keep the casts for now)
Sample command for testing:
llama-embedding -hf ggml-org/bge-small-en-v1.5-Q8_0-GGUF -e -p "$(printf 'hello %.0s' {1..510})" --pooling cls -c 512 -fa on