WIP: ggml-cuda: Add bf16 cuda support to fattn (Flash Attention) #15261
base: master
Changes from 1 commit
@@ -649,10 +649,52 @@ static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k,
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cont_cuda<float>;
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:
            return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
        case GGML_TYPE_BF16:
            return convert_unary_cont_cuda<nv_bfloat16>;
Comment on lines +652 to +697

Reviewer: Please keep the order of the types the same as their definition in the enum.

Author: Are we talking about the ggml_type enum? None of the three conversion functions list the types in the same order. Should I update them all to be consistent? My change was just copying the f32 version

    $ diff bf16 f32
    46,47d45
    < case GGML_TYPE_F32:
    < return convert_unary_cont_cuda<float>;

and adding the GGML_TYPE_F32 case. Happy to fix up the cases as appropriate, just let me know how you want to proceed.

Reviewer: Regardless of what order the types were in before, when you do a switch like this, please use the order as declared in the enum.
        default:
            return nullptr;
    }
}
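For context, a minimal sketch of how a caller might use the new getter, assuming to_bf16_cuda_t mirrors the existing to_fp16_cuda_t signature (a function pointer taking const void * x, nv_bfloat16 * y, int64_t k, cudaStream_t stream); the helper name convert_src_to_bf16 is hypothetical and not part of this PR:

#include <cuda_bf16.h>  // nv_bfloat16

// Hypothetical helper: convert a contiguous source tensor of any supported
// ggml_type into a bf16 buffer on the given stream, e.g. to feed bf16 K/V
// data into the FlashAttention path.
static void convert_src_to_bf16(const ggml_tensor * src, nv_bfloat16 * dst,
                                int64_t n_elements, cudaStream_t stream) {
    const to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(src->type);
    if (to_bf16 == nullptr) {
        GGML_ABORT("%s: unsupported type %s for bf16 conversion", __func__, ggml_type_name(src->type));
    }
    to_bf16(src->data, dst, n_elements, stream);
}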
Reviewer: This will break CUDA 11 compatibility.
Author: This was some hackery due to an issue I was having getting the F32 kernel device code compiled; at the time I wasn't confident CMake was respecting my -DCMAKE_CUDA_ARCHITECTURES flag. I reverted it instead of adjusting it since it's not relevant to the change.