@@ -251,6 +251,8 @@ struct vk_device_struct {
251251 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
252252 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
253253 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
254+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
255+ vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
254256 vk_pipeline pipeline_argsort_f32;
255257 vk_pipeline pipeline_sum_rows_f32;
256258 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
494496 float corr_dims[2 ];
495497 float theta_scale;
496498 uint32_t has_ff;
499+ uint32_t ne02;
500+ uint32_t s1;
501+ uint32_t s2;
502+ int32_t sections[4 ];
497503};
498504
499505struct vk_op_soft_max_push_constants {
@@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
21802186
21812187 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f32 , " rope_norm_f32" , rope_norm_f32_len, rope_norm_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21822188 ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f32 , " rope_neox_f32" , rope_neox_f32_len, rope_neox_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2189+ ggml_vk_create_pipeline (device, device->pipeline_rope_multi_f32 , " rope_multi_f32" , rope_multi_f32_len, rope_multi_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2190+ ggml_vk_create_pipeline (device, device->pipeline_rope_vision_f32 , " rope_vision_f32" , rope_vision_f32_len, rope_vision_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21832191
21842192 if (device->float_controls_rte_fp16 ) {
21852193 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_rte_len, rope_norm_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21862194 ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16" , rope_neox_f16_rte_len, rope_neox_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2195+ ggml_vk_create_pipeline (device, device->pipeline_rope_multi_f16 , " rope_multi_f16" , rope_multi_f16_rte_len, rope_multi_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2196+ ggml_vk_create_pipeline (device, device->pipeline_rope_vision_f16 , " rope_vision_f16" , rope_vision_f16_rte_len, rope_vision_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21872197 } else {
21882198 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_len, rope_norm_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21892199 ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16" , rope_neox_f16_len, rope_neox_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2200+ ggml_vk_create_pipeline (device, device->pipeline_rope_multi_f16 , " rope_multi_f16" , rope_multi_f16_len, rope_multi_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
2201+ ggml_vk_create_pipeline (device, device->pipeline_rope_vision_f16 , " rope_vision_f16" , rope_vision_f16_len, rope_vision_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21902202 }
21912203
21922204 ggml_vk_create_pipeline (device, device->pipeline_argsort_f32 , " argsort_f32" , argsort_f32_len, argsort_f32_data, " main" , 2 , sizeof (vk_op_argsort_push_constants), {1024 , 1 , 1 }, {}, 1 );
@@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53075319 {
53085320 const int mode = ((const int32_t *) dst->op_params )[2 ];
53095321 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5322+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5323+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
53105324
53115325 if (is_neox) {
53125326 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53155329 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
53165330 return ctx->device ->pipeline_rope_neox_f16 ;
53175331 }
5332+ } else if (is_mrope && !is_vision) {
5333+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5334+ return ctx->device ->pipeline_rope_multi_f32 ;
5335+ }
5336+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5337+ return ctx->device ->pipeline_rope_multi_f16 ;
5338+ }
5339+ } else if (is_vision) {
5340+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5341+ return ctx->device ->pipeline_rope_vision_f32 ;
5342+ }
5343+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5344+ return ctx->device ->pipeline_rope_vision_f16 ;
5345+ }
53185346 } else {
53195347 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53205348 return ctx->device ->pipeline_rope_norm_f32 ;
@@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
53855413 case GGML_OP_CLAMP:
53865414 case GGML_OP_PAD:
53875415 case GGML_OP_REPEAT:
5416+ case GGML_OP_ROPE:
53885417 return true ;
53895418 default :
53905419 return false ;
@@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
61496178
61506179static void ggml_vk_rope (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false ) {
61516180 const int n_dims = ((int32_t *) dst->op_params )[1 ];
6152- // const int mode = ((int32_t *) dst->op_params)[2];
6181+ const int mode = ((int32_t *) dst->op_params )[2 ];
61536182 // const int n_ctx = ((int32_t *) dst->op_params)[3];
61546183 const int n_ctx_orig = ((int32_t *) dst->op_params )[4 ];
61556184 const float freq_base = ((float *) dst->op_params )[5 ];
@@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
61586187 const float attn_factor = ((float *) dst->op_params )[8 ];
61596188 const float beta_fast = ((float *) dst->op_params )[9 ];
61606189 const float beta_slow = ((float *) dst->op_params )[10 ];
6190+ int sections[4 ] {};
6191+ if (mode & GGML_ROPE_TYPE_MROPE) {
6192+ memcpy (sections, (int32_t *) dst->op_params + 11 , sizeof (int )*4 );
6193+ }
61616194
61626195 float corr_dims[2 ];
61636196 ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
61646197
61656198 const float theta_scale = powf (freq_base, -2 .0f /n_dims);
61666199
6200+ uint32_t s1 = src0->nb [1 ] / ggml_type_size (src0->type );
6201+ uint32_t s2 = src0->nb [2 ] / ggml_type_size (src0->type );
6202+
61676203 ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
61686204 (uint32_t )src0->ne [0 ], (uint32_t )n_dims, freq_scale, (uint32_t )src0->ne [1 ],
61696205 freq_base, ext_factor, attn_factor, {corr_dims[0 ], corr_dims[1 ]}, theta_scale,
6170- src2 != nullptr ,
6206+ src2 != nullptr , (uint32_t )src0->ne [2 ], s1, s2,
6207+ sections[0 ], sections[1 ], sections[2 ], sections[3 ],
61716208 }, dryrun);
61726209}
61736210
@@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
82648301 case GGML_OP_REPEAT:
82658302 return ggml_type_size (op->type ) == sizeof (float ) && ggml_type_size (op->src [0 ]->type ) == sizeof (float );
82668303 case GGML_OP_ROPE:
8267- {
8268- const int mode = ((const int32_t *) op->op_params )[2 ];
8269- if (mode & GGML_ROPE_TYPE_MROPE) {
8270- return false ;
8271- }
8272- if (mode & GGML_ROPE_TYPE_VISION) {
8273- return false ;
8274- }
8275- return ggml_is_contiguous (op->src [0 ]);
8276- }
82778304 case GGML_OP_NONE:
82788305 case GGML_OP_RESHAPE:
82798306 case GGML_OP_VIEW:
@@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88318858 const float attn_factor = ((float *) tensor->op_params )[8 ];
88328859 const float beta_fast = ((float *) tensor->op_params )[9 ];
88338860 const float beta_slow = ((float *) tensor->op_params )[10 ];
8834- tensor_clone = ggml_rope_ext (ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8861+ if (mode & GGML_ROPE_TYPE_MROPE) {
8862+ int32_t *sections = ((int32_t *) tensor->op_params ) + 11 ;
8863+ tensor_clone = ggml_rope_multi (ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8864+ } else {
8865+ tensor_clone = ggml_rope_ext (ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8866+ }
88358867 } else if (tensor->op == GGML_OP_UNARY) {
88368868 switch (ggml_get_unary_op (tensor)) {
88378869 case GGML_UNARY_OP_SILU:
0 commit comments