Skip to content

Commit d888959

Browse files
committed
kompute: softmax: implement ALiBi support
Signed-off-by: Sergio Lopez <[email protected]>
1 parent 913536f commit d888959

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

ggml/src/ggml-kompute/ggml-kompute.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -788,20 +788,23 @@ static void ggml_vk_soft_max(
788788
const std::shared_ptr<kp::Tensor>& out,
789789
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
790790
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
791-
float scale
791+
float scale, float max_bias, float m0, float m1,
792+
uint32_t n_head_log2
792793
) {
793794
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
794795
kp::shader_data::op_softmax_comp_spv_len);
795796

796797
struct PushConstants {
797798
uint32_t inAOff, inBOff, outOff;
798799
int32_t ne00, ne01, ne02;
799-
float scale;
800+
float scale, max_bias, m0, m1;
801+
uint32_t n_head_log2;
800802
int32_t mask;
801803
} pushConsts {
802804
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
803805
ne00, ne01, ne02,
804-
scale,
806+
scale, max_bias, m0, m1,
807+
n_head_log2,
805808
bool(inB)
806809
};
807810

@@ -1597,11 +1600,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
15971600
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
15981601
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
15991602

1600-
#pragma message("TODO: add ALiBi support")
1601-
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1602-
GGML_ASSERT(max_bias == 0.0f);
1603+
const int64_t nrows_x = ggml_nrows(src0);
1604+
const int64_t nrows_y = src0->ne[1];
1605+
1606+
const uint32_t n_head = nrows_x/nrows_y;
1607+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1608+
1609+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1610+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
16031611

1604-
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1612+
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
16051613
} break;
16061614
case GGML_OP_DIAG_MASK_INF:
16071615
{

ggml/src/ggml-kompute/kompute-shaders/common.comp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
44
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
55
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
6+
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
67
#extension GL_EXT_control_flow_attributes: enable
78
#extension GL_KHR_shader_subgroup_arithmetic : require
89
#extension GL_EXT_debug_printf : enable

ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
1818
int ne01;
1919
int ne02;
2020
float scale;
21+
float max_bias;
22+
float m0;
23+
float m1;
24+
uint n_head_log2;
2125
int mask;
2226
} pcs;
2327

@@ -34,17 +38,29 @@ void main() {
3438
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
3539
const uint pdst = extra_off + pcs.outOff; // Based from out_
3640

41+
float slope = 1.0f;
42+
43+
// ALiBi
44+
if (pcs.max_bias > 0.0f) {
45+
int64_t h = i02;
46+
47+
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
48+
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
49+
50+
slope = pow(base, float(exp));
51+
}
52+
3753
// parallel max
3854
float localMax = uintBitsToFloat(0xFF800000);
3955
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
40-
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
56+
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
4157
}
4258
float max_ = subgroupMax(localMax);
4359

4460
// parallel sum
4561
float localSum = 0.0f;
4662
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
47-
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
63+
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
4864
localSum += exp_psrc0;
4965
out_[pdst + i00] = exp_psrc0;
5066
}

0 commit comments

Comments
 (0)