@@ -556,15 +556,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
556556}
557557
558558// implementation of the 2D RoPE without adding a new op in ggml
559+ // this is not efficient (it uses double the memory), but works on all backends
559560static ggml_tensor * build_rope_2d (
560- ggml_cgraph * gf,
561561 ggml_context * ctx0,
562562 ggml_tensor * cur,
563563 ggml_tensor * pos_h,
564564 ggml_tensor * pos_w,
565565 const float freq_base
566566) {
567- ggml_tensor * tmp;
568567 const int64_t n_dim = cur->ne [0 ];
569568 const int64_t n_head = cur->ne [1 ];
570569 const int64_t n_pos = cur->ne [2 ];
@@ -573,18 +572,24 @@ static ggml_tensor * build_rope_2d(
573572 // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
574573 // first half of cur will use 1e-0, 1e-2 (even)
575574 // second half of cur will use 1e-1, 1e-3 (odd)
576- //
577- // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
575+ // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
578576 // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
579577 // then for the second half, we use freq_scale to shift the inv_freq
580578 // ^ why? replace (2i) with (2i+1) in the above equation
581579 const float freq_scale_odd = std::pow (freq_base, (float )-2 /n_dim);
582580
583581 // first half
582+ ggml_tensor * first;
584583 {
585- cur = ggml_rope_ext_inplace (
584+ first = ggml_view_3d (ctx0, cur,
585+ n_dim/2 , n_head, n_pos,
586+ ggml_row_size (cur->type , n_dim),
587+ ggml_row_size (cur->type , n_dim*n_head),
588+ 0 );
589+ // first = ggml_cont(ctx0, first);
590+ first = ggml_rope_ext (
586591 ctx0,
587- cur ,
592+ first ,
588593 pos_h, // positions
589594 nullptr , // freq factors
590595 n_dim/2 , // n_dims
@@ -593,27 +598,28 @@ static ggml_tensor * build_rope_2d(
593598 );
594599 }
595600
596- // second half
601+ // second half (result is stored in `second`)
602+ ggml_tensor * second = cur;
597603 {
598- tmp = ggml_view_3d (ctx0, cur,
604+ second = ggml_view_3d (ctx0, cur,
599605 n_dim/2 , n_head, n_pos,
600606 ggml_row_size (cur->type , n_dim),
601607 ggml_row_size (cur->type , n_dim*n_head),
602608 n_dim/2 * ggml_element_size (cur));
603- tmp = ggml_rope_ext_inplace (
609+ second = ggml_cont (ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
610+ second = ggml_rope_ext (
604611 ctx0,
605- tmp ,
612+ second ,
606613 pos_w, // positions
607614 nullptr , // freq factors
608615 n_dim/2 , // n_dims
609616 0 , 0 , freq_base,
610617 freq_scale_odd,
611618 0 .0f , 1 .0f , 0 .0f , 0 .0f
612619 );
613- // calculate inplace (modify cur directly)
614- ggml_build_forward_expand (gf, tmp);
615620 }
616621
622+ cur = ggml_concat (ctx0, first, second, 0 );
617623 return cur;
618624}
619625
@@ -682,13 +688,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
682688 struct ggml_tensor * Q = ggml_mul_mat (ctx0, model.layers [il].q_w , cur);
683689
684690 Q = ggml_reshape_3d (ctx0, Q, d_head, n_head, num_patches);
685- Q = build_rope_2d (gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta );
691+ Q = build_rope_2d (ctx0, Q, pos_h, pos_w, hparams.rope_theta );
686692 Q = ggml_cont (ctx0, ggml_permute (ctx0, Q, 0 , 2 , 1 , 3 ));
687693
688694 struct ggml_tensor * K = ggml_mul_mat (ctx0, model.layers [il].k_w , cur);
689695
690696 K = ggml_reshape_3d (ctx0, K, d_head, n_head, num_patches);
691- K = build_rope_2d (gf, ctx0, K, pos_h, pos_w, hparams.rope_theta );
697+ K = build_rope_2d (ctx0, K, pos_h, pos_w, hparams.rope_theta );
692698 K = ggml_cont (ctx0, ggml_permute (ctx0, K, 0 , 2 , 1 , 3 ));
693699
694700 struct ggml_tensor * V = ggml_mul_mat (ctx0, model.layers [il].v_w , cur);
0 commit comments