
Commit 1b8afa8 (parent: d888959)

kompute: rope: implement neox and phi3 support

Signed-off-by: Sergio Lopez <[email protected]>

9 files changed: 258 additions, 176 deletions

ggml/src/ggml-kompute/CMakeLists.txt (8 additions, 4 deletions)

@@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
         kompute-shaders/op_getrows_q4_0.comp
         kompute-shaders/op_getrows_q4_1.comp
         kompute-shaders/op_getrows_q6_k.comp
-        kompute-shaders/op_rope_f16.comp
-        kompute-shaders/op_rope_f32.comp
+        kompute-shaders/op_rope_norm_f16.comp
+        kompute-shaders/op_rope_norm_f32.comp
+        kompute-shaders/op_rope_neox_f16.comp
+        kompute-shaders/op_rope_neox_f32.comp
         kompute-shaders/op_cpy_f16_f16.comp
         kompute-shaders/op_cpy_f16_f32.comp
         kompute-shaders/op_cpy_f32_f16.comp
@@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
         shaderop_getrows_q4_0.h
         shaderop_getrows_q4_1.h
         shaderop_getrows_q6_k.h
-        shaderop_rope_f16.h
-        shaderop_rope_f32.h
+        shaderop_rope_norm_f16.h
+        shaderop_rope_norm_f32.h
+        shaderop_rope_neox_f16.h
+        shaderop_rope_neox_f32.h
         shaderop_cpy_f16_f16.h
         shaderop_cpy_f16_f32.h
         shaderop_cpy_f32_f16.h
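
Each .comp shader listed above is compiled to SPIR-V at build time and embedded into the matching shaderop_*.h header, which ggml-kompute.cpp includes and feeds to getSpirvShader. As a rough sketch of the shape such a generated header takes (illustrative only; the real file is emitted by the build and carries the actual byte array, possibly as an inline definition rather than an extern declaration):

    // shaderop_rope_neox_f16.h: shape sketch of a generated header, not the real output
    namespace kp {
    namespace shader_data {
    // compiled SPIR-V for kompute-shaders/op_rope_neox_f16.comp
    extern const unsigned char op_rope_neox_f16_comp_spv[];
    extern const unsigned int  op_rope_neox_f16_comp_spv_len; // size in bytes
    } // namespace shader_data
    } // namespace kp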

ggml/src/ggml-kompute/ggml-kompute.cpp (40 additions, 26 deletions)

@@ -28,8 +28,10 @@
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
+#include "shaderop_rope_norm_f16.h"
+#include "shaderop_rope_norm_f32.h"
+#include "shaderop_rope_neox_f16.h"
+#include "shaderop_rope_neox_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
     std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
         vk::DescriptorPoolSize(
             vk::DescriptorType::eStorageBuffer,
-            3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+            4 * size // Descriptor count is number of possible tensors to pass into an algorithm
         )
     };
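
Note on the sizing: with the optional frequency-factors tensor, the RoPE dispatch below binds four storage buffers ({inA, inB, inC_, out}) instead of three, so the per-algorithm descriptor budget grows to match.
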
@@ -1220,22 +1222,29 @@ static void ggml_vk_rope(
     kp::Sequence& seq,
     const std::shared_ptr<kp::Tensor>& inA,
     const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& inC,
     const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
     ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+    float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
     int32_t ne0,
     uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
 ) {
     GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);

-    static const auto spirv_f16 = getSpirvShader(
-        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    static const auto spirv_norm_f16 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+    );
+    static const auto spirv_norm_f32 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
+    );
+    static const auto spirv_neox_f16 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
     );
-    static const auto spirv_f32 = getSpirvShader(
-        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    static const auto spirv_neox_f32 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
     );

     int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
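
All four SPIR-V blobs are decoded once on first use (function-local statics), so variant selection at dispatch time is just picking a reference.
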
@@ -1250,32 +1259,40 @@ static void ggml_vk_rope(
     GGML_ASSERT(nb0 % type_size == 0);

     struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
+        uint32_t inAOff, inBOff, inCOff, outOff;
         int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+        float freq_base, freq_scale;
+        bool has_freq_factors;
+        float ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
         n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+        freq_base, freq_scale,
+        has_freq_factors,
+        ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
         nb0, nb1, nb2, nb3
     };

-    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+    auto & inC_ = inC ? inC : inA;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_f16 = src0t == GGML_TYPE_F16;
+
+    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
+        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
         s_algo = komputeManager()->algorithm<float, PushConstants>(
-            name, s_kompute_context->pool.get(), {inA, inB, out},
-            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
             {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
         );
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
+        s_algo->setTensors({inA, inB, inC_, out});
         s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
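
Since komputeManager() caches algorithms by name, the cache key now encodes both the rope mode and the element type; with __func__ expanding to "ggml_vk_rope", the four possible keys are ggml_vk_rope_norm_f16, ggml_vk_rope_norm_f32, ggml_vk_rope_neox_f16 and ggml_vk_rope_neox_f32. Without the _norm/_neox suffix, a pipeline compiled for one pairing could be fetched for the other. Note also the inC_ fallback: when no frequency-factors tensor is present, inA is rebound at binding 2 as a stand-in so the descriptor set stays fully populated, and the shader guards the read with pcs.has_freq_factors.
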
@@ -1522,9 +1539,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
         uint32_t off_src0 = 0;
         uint32_t off_src1 = 0;
+        uint32_t off_src2 = 0;
         uint32_t off_dst = 0;
         const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+        const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
@@ -1721,13 +1740,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 } break;
             case GGML_OP_ROPE:
                 {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
                     GGML_ASSERT(ne10 == ne02);
                     GGML_ASSERT(src0t == dstt);
                     // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1736,6 +1748,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

+                    const bool has_freq_factors = dst->src[2] != nullptr;
+
                     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
                     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -1744,8 +1758,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
                     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
                     ggml_vk_rope(
-                        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                        seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+                        freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
                         ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                     );
                 } break;
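
The float hyper-parameters travel through ggml's int32_t op_params array and are recovered bit-for-bit with memcpy. A standalone sketch of the pattern (read_f32_param is a hypothetical helper, not ggml API; indices 5 through 10 match the reads above):

    #include <cstdint>
    #include <cstring>

    // Bit-copy a float out of an int32_t-typed parameter array.
    // memcpy avoids the strict-aliasing UB that casting and dereferencing would invite.
    static float read_f32_param(const int32_t * op_params, int idx) {
        float v;
        std::memcpy(&v, op_params + idx, sizeof(float));
        return v;
    }

    // e.g. freq_base = read_f32_param(dst->op_params, 5);
    //      beta_slow = read_f32_param(dst->op_params, 10);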

ggml/src/ggml-kompute/kompute-shaders/op_rope_f16.comp: deleted (73 lines)

ggml/src/ggml-kompute/kompute-shaders/op_rope_f32.comp: deleted (73 lines)
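
The two deleted shaders are superseded by the four specialized variants (norm/neox, f16/f32) registered in CMakeLists.txt above: one compiled shader per mode/type combination.
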
ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp: new file (52 lines)
#version 450

#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };

void main() {
    const uint i3 = gl_WorkGroupID.z;
    const uint i2 = gl_WorkGroupID.y;
    const uint i1 = gl_WorkGroupID.x;

    float corr_dims[2];
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

    float theta_base = float(inB[pcs.inBOff + i2]);
    float inv_ndims = -1.f/pcs.n_dims;

    float cos_theta;
    float sin_theta;

    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
        if (i0 < pcs.n_dims) {
            uint ic = i0/2;

            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);

            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;

            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            const float x0 = float(inA[src]);
            const float x1 = float(inA[src+pcs.n_dims/2]);

            out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
        } else {
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
            const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_

            out_[dst_data] = inA[src];
            out_[dst_data+1] = inA[src+1];
        }
    }
}
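
For orientation: NEOX mode rotates element pairs that sit half a rope block apart (indices ic and ic + n_dims/2), where NORM mode rotates adjacent elements. A minimal CPU sketch of the rotation above in float32, leaving out the YaRN terms the shader gets from rope_yarn (freq_scale, ext_factor, attn_factor, corr_dims) so only the pairing and the phi3 frequency factors remain:

    #include <cmath>

    // Rotate the first n_dims entries of one row with the NEOX pairing.
    // freq_factors may be null (no phi3 factors bound), matching has_freq_factors == false.
    // Entries past n_dims pass through unchanged, as in the shader's else branch.
    void rope_neox_row_sketch(float * x, int n_dims, float theta_base,
                              float freq_base, const float * freq_factors) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const int ic = i0 / 2;
            float theta = theta_base * std::pow(freq_base, -(float)i0 / (float)n_dims);
            if (freq_factors) {
                theta /= freq_factors[ic]; // per-pair scaling for phi3-style RoPE
            }
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = x[ic];               // pair members sit n_dims/2 apart
            const float x1 = x[ic + n_dims/2];
            x[ic]            = x0*c - x1*s;
            x[ic + n_dims/2] = x0*s + x1*c;
        }
    }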
