@@ -788,20 +788,23 @@ static void ggml_vk_soft_max(
788788 const std::shared_ptr<kp::Tensor>& out,
789789 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
790790 int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
791- float scale
791+ float scale, float max_bias, float m0, float m1,
792+ uint32_t n_head_log2
792793) {
793794 const static auto spirv = getSpirvShader (kp::shader_data::op_softmax_comp_spv,
794795 kp::shader_data::op_softmax_comp_spv_len);
795796
796797 struct PushConstants {
797798 uint32_t inAOff, inBOff, outOff;
798799 int32_t ne00, ne01, ne02;
799- float scale;
800+ float scale, max_bias, m0, m1;
801+ uint32_t n_head_log2;
800802 int32_t mask;
801803 } pushConsts {
802804 safe_divide (inAOff, 4 ), safe_divide (inBOff, 4 ), safe_divide (outOff, 4 ),
803805 ne00, ne01, ne02,
804- scale,
806+ scale, max_bias, m0, m1,
807+ n_head_log2,
805808 bool (inB)
806809 };
807810
@@ -1597,11 +1600,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
15971600#pragma message("ref: https:// github.com/ggerganov/llama.cpp/pull/5021")
15981601 GGML_ASSERT (!src1 || src1t == GGML_TYPE_F32);
15991602
1600- #pragma message("TODO: add ALiBi support")
1601- #pragma message("ref: https:// github.com/ggerganov/llama.cpp/pull/7192")
1602- GGML_ASSERT (max_bias == 0 .0f );
1603+ const int64_t nrows_x = ggml_nrows (src0);
1604+ const int64_t nrows_y = src0->ne [1 ];
1605+
1606+ const uint32_t n_head = nrows_x/nrows_y;
1607+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
1608+
1609+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1610+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
16031611
1604- ggml_vk_soft_max (seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1612+ ggml_vk_soft_max (seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2 );
16051613 } break ;
16061614 case GGML_OP_DIAG_MASK_INF:
16071615 {
0 commit comments