@@ -28,8 +28,10 @@
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
+#include "shaderop_rope_norm_f16.h"
+#include "shaderop_rope_norm_f32.h"
+#include "shaderop_rope_neox_f16.h"
+#include "shaderop_rope_neox_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
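
Note: each shaderop_*.h header embeds the SPIR-V compiled from the matching
.comp shader as a byte array under the kp::shader_data namespace, which is
what the getSpirvShader() calls further down read. A rough sketch of the shape
of such a generated header (illustrative; the real file is emitted by the
build, and only the symbol names are taken from this diff):

    // Hypothetical outline of shaderop_rope_norm_f16.h
    namespace kp {
    namespace shader_data {
    static const unsigned char op_rope_norm_f16_comp_spv[] = {
        0x03, 0x02, 0x23, 0x07, /* ... remaining SPIR-V bytes ... */
    };
    static const unsigned int op_rope_norm_f16_comp_spv_len =
        sizeof(op_rope_norm_f16_comp_spv);
    }
    }
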
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
     std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
         vk::DescriptorPoolSize(
             vk::DescriptorType::eStorageBuffer,
-            3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+            4 * size // Descriptor count is number of possible tensors to pass into an algorithm
         )
     };
 
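
Note: the bump from 3 to 4 descriptors per algorithm follows from RoPE now
binding four storage buffers (inA, inB, the new inC frequency-factor tensor,
and out). A minimal standalone sketch of the sizing logic, assuming a
vulkan-hpp vk::Device named `device` and an algorithm count `size` (an
illustration, not the backend's actual allocation path):

    #include <vulkan/vulkan.hpp>

    vk::UniqueDescriptorPool make_pool(vk::Device device, uint32_t size) {
        // Reserve enough storage-buffer descriptors for `size` algorithms,
        // each of which may bind up to 4 tensors (inA, inB, inC, out).
        vk::DescriptorPoolSize poolSize(vk::DescriptorType::eStorageBuffer, 4 * size);
        vk::DescriptorPoolCreateInfo info(vk::DescriptorPoolCreateFlags(), /*maxSets=*/size, 1, &poolSize);
        return device.createDescriptorPoolUnique(info);
    }
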
@@ -1220,22 +1222,29 @@ static void ggml_vk_rope(
     kp::Sequence& seq,
     const std::shared_ptr<kp::Tensor>& inA,
     const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& inC,
     const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
     ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+    float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
     int32_t ne0,
     uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
 ) {
     GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
 
-    static const auto spirv_f16 = getSpirvShader(
-        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    static const auto spirv_norm_f16 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+    );
+    static const auto spirv_norm_f32 = getSpirvShader(
+        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
+    );
+    static const auto spirv_neox_f16 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
     );
-    static const auto spirv_f32 = getSpirvShader(
-        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    static const auto spirv_neox_f32 = getSpirvShader(
+        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
     );
 
     int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@@ -1250,32 +1259,40 @@ static void ggml_vk_rope(
     GGML_ASSERT(nb0 % type_size == 0);
 
     struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
+        uint32_t inAOff, inBOff, inCOff, outOff;
         int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+        float freq_base, freq_scale;
+        bool has_freq_factors;
+        float ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
         n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+        freq_base, freq_scale,
+        has_freq_factors,
+        ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
         nb0, nb1, nb2, nb3
     };
 
-    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+    auto & inC_ = inC ? inC : inA;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_f16 = src0t == GGML_TYPE_F16;
+
+    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
     std::shared_ptr<kp::Algorithm> s_algo = nullptr;
     if (!komputeManager()->hasAlgorithm(name)) {
+        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
         s_algo = komputeManager()->algorithm<float, PushConstants>(
-            name, s_kompute_context->pool.get(), {inA, inB, out},
-            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
             {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
         );
     } else {
         s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
+        s_algo->setTensors({inA, inB, inC_, out});
         s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
         s_algo->setPushConstants<PushConstants>({pushConsts});
         s_algo->updateDescriptors(s_kompute_context->pool.get());
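
Note: two details in this hunk are easy to miss. First, `inC_ = inC ? inC : inA`
handles the case where no frequency-factor tensor exists: Kompute cannot bind a
null tensor, so inA is rebound as a placeholder in the third slot and the shader
is expected to ignore it when has_freq_factors is false. Second, a C++ `bool`
inside a push-constant struct is 1 byte (padded to 4 here by the following
float's alignment) while a GLSL `bool` occupies a full 4-byte slot, so the
layouts line up but are worth double-checking whenever the struct changes. The
shader-variant selection can be restated as below (a sketch; only
GGML_ROPE_TYPE_NEOX's value is taken from ggml.h, the rest is illustrative):

    #include <cstdint>

    constexpr int32_t kRopeTypeNeox = 2; // value of GGML_ROPE_TYPE_NEOX in ggml.h

    enum class RopeShader { NormF16, NormF32, NeoxF16, NeoxF32 };

    // The NEOX bit in `mode` picks the shader family; the element type
    // picks the precision, mirroring the nested ternary above.
    RopeShader select_rope_shader(int32_t mode, bool is_f16) {
        const bool is_neox = mode & kRopeTypeNeox;
        if (is_neox) {
            return is_f16 ? RopeShader::NeoxF16 : RopeShader::NeoxF32;
        }
        return is_f16 ? RopeShader::NormF16 : RopeShader::NormF32;
    }
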
@@ -1522,9 +1539,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
         uint32_t off_src0 = 0;
         uint32_t off_src1 = 0;
+        uint32_t off_src2 = 0;
         uint32_t off_dst  = 0;
         const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+        const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
         const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
 
         switch (dst->op) {
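
Note: `src2` is presumably resolved from dst->src[2] earlier in
ggml_vk_graph_compute, alongside src0 and src1; this hunk does not show that
line, so the following one-liner is a hypothetical reconstruction:

    struct ggml_tensor * src2 = dst ? dst->src[2] : nullptr; // assumed, not shown in the diff
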
@@ -1721,13 +1740,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 } break;
             case GGML_OP_ROPE:
                 {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
                     GGML_ASSERT(ne10 == ne02);
                     GGML_ASSERT(src0t == dstt);
                     // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1736,6 +1748,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
+                    const bool has_freq_factors = dst->src[2] != nullptr;
+
                     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                     memcpy(&freq_base,  (int32_t *) dst->op_params + 5, sizeof(float));
                     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -1744,8 +1758,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     memcpy(&beta_fast,  (int32_t *) dst->op_params + 9,  sizeof(float));
                     memcpy(&beta_slow,  (int32_t *) dst->op_params + 10, sizeof(float));
                     ggml_vk_rope(
-                        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                        seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+                        freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
                         ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                     );
                 } break;
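
Note: RoPE packs its scalar parameters into dst->op_params as 32-bit slots,
with float values bit-copied via memcpy to sidestep strict-aliasing issues. A
condensed sketch of the read side; the indices for slots 0, 4, and 5-10 come
from this diff, while n_dims at slot 1 and mode at slot 2 follow ggml's usual
RoPE layout (an assumption, as those reads are outside the hunk):

    #include <cstdint>
    #include <cstring>

    struct RopeParams {
        int32_t n_dims, mode, n_ctx_orig;
        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    };

    RopeParams read_rope_params(const int32_t * op_params) {
        RopeParams p;
        p.n_dims     = op_params[1];  // slot 0 is n_past, unused here
        p.mode       = op_params[2];  // slot 3 is n_ctx (GLM RoPE), unimplemented
        p.n_ctx_orig = op_params[4];
        std::memcpy(&p.freq_base,   op_params + 5,  sizeof(float));
        std::memcpy(&p.freq_scale,  op_params + 6,  sizeof(float));
        std::memcpy(&p.ext_factor,  op_params + 7,  sizeof(float));
        std::memcpy(&p.attn_factor, op_params + 8,  sizeof(float));
        std::memcpy(&p.beta_fast,   op_params + 9,  sizeof(float));
        std::memcpy(&p.beta_slow,   op_params + 10, sizeof(float));
        return p;
    }
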