@@ -130,6 +130,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
130130
131131 const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
132132 debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
133+
134+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
135+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
136+
137+ if (!supports_set_rows) {
138+ LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
139+ }
133140}
134141
135142void llama_kv_cache_unified::clear (bool data) {
@@ -751,15 +758,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
751758
752759 auto * k = layers[ikv].k ;
753760
761+ const int64_t n_embd_k_gqa = k->ne [0 ];
754762 const int64_t n_tokens = k_cur->ne [2 ];
755763
756- if (kv_idxs) {
757- return ggml_set_rows (ctx, k, ggml_reshape_2d (ctx, k_cur, k->ne [0 ], n_tokens), kv_idxs);
764+ k_cur = ggml_reshape_2d (ctx, k_cur, k->ne [0 ], n_tokens);
765+
766+ if (kv_idxs && supports_set_rows) {
767+ return ggml_set_rows (ctx, k, k_cur, kv_idxs);
758768 }
759769
770+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
771+ // will be removed when ggml_set_rows() is adopted by all backends
772+
760773 ggml_tensor * k_view = ggml_view_1d (ctx, k,
761- n_tokens*hparams. n_embd_k_gqa (il) ,
762- ggml_row_size (k->type , hparams. n_embd_k_gqa (il) )*head_cur);
774+ n_tokens*n_embd_k_gqa,
775+ ggml_row_size (k->type , n_embd_k_gqa)*head_cur);
763776
764777 return ggml_cpy (ctx, k_cur, k_view);
765778}
@@ -769,37 +782,43 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
769782
770783 auto * v = layers[ikv].v ;
771784
785+ const int64_t n_embd_v_gqa = v->ne [0 ];
772786 const int64_t n_tokens = v_cur->ne [2 ];
773787
774- v_cur = ggml_reshape_2d (ctx, v_cur, hparams.n_embd_v_gqa (il), n_tokens);
775-
776- ggml_tensor * v_view = nullptr ;
788+ v_cur = ggml_reshape_2d (ctx, v_cur, n_embd_v_gqa, n_tokens);
777789
778- if (!v_trans ) {
779- if (kv_idxs ) {
790+ if (kv_idxs && supports_set_rows ) {
791+ if (!v_trans ) {
780792 return ggml_set_rows (ctx, v, v_cur, kv_idxs);
781793 }
782794
783- v_view = ggml_view_1d (ctx, v,
784- n_tokens*hparams.n_embd_v_gqa (il),
785- ggml_row_size (v->type , hparams.n_embd_v_gqa (il))*head_cur);
786- } else {
795+ // note: the V cache is transposed when not using flash attention
787796 v_cur = ggml_transpose (ctx, v_cur);
788797
789- // note: the V cache is transposed when not using flash attention
790- if (kv_idxs) {
791- // the row becomes a single element and we repeat the KV indices d_head times
792- // TODO: this seems not very optimal - can we do something better?
793- v_view = ggml_reshape_3d (ctx, v, 1 , v->ne [1 ], v->ne [0 ]);
798+ // the row becomes a single element and we repeat the KV indices d_head times
799+ ggml_tensor * v_view = ggml_reshape_3d (ctx, v, 1 , v->ne [1 ], v->ne [0 ]);
794800
795- v_cur = ggml_cont_3d (ctx, v_cur, 1 , v_cur->ne [0 ], v_cur->ne [1 ]);
801+ v_cur = ggml_cont_3d (ctx, v_cur, 1 , v_cur->ne [0 ], v_cur->ne [1 ]);
796802
797- kv_idxs = ggml_repeat_4d (ctx, kv_idxs, v_cur->ne [1 ], v_cur->ne [2 ], 1 , 1 );
803+ // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
804+ kv_idxs = ggml_repeat_4d (ctx, kv_idxs, v_cur->ne [1 ], v_cur->ne [2 ], 1 , 1 );
798805
799- return ggml_set_rows (ctx, v_view, v_cur, kv_idxs);
800- }
806+ return ggml_set_rows (ctx, v_view, v_cur, kv_idxs);
807+ }
808+
809+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
810+ // will be removed when ggml_set_rows() is adopted by all backends
811+
812+ ggml_tensor * v_view = nullptr ;
813+
814+ if (!v_trans) {
815+ v_view = ggml_view_1d (ctx, v,
816+ n_tokens*n_embd_v_gqa,
817+ ggml_row_size (v->type , n_embd_v_gqa)*head_cur);
818+ } else {
819+ v_cur = ggml_transpose (ctx, v_cur);
801820
802- v_view = ggml_view_2d (ctx, v, n_tokens, hparams. n_embd_v_gqa (il) ,
821+ v_view = ggml_view_2d (ctx, v, n_tokens, n_embd_v_gqa,
803822 (v->ne [1 ])*ggml_element_size (v),
804823 (head_cur)*ggml_element_size (v));
805824 }
@@ -808,6 +827,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
808827}
809828
810829void llama_kv_cache_unified::set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
830+ if (!supports_set_rows) {
831+ return ;
832+ }
833+
811834 const uint32_t n_tokens = ubatch->n_tokens ;
812835
813836 GGML_ASSERT (ggml_backend_buffer_is_host (dst->buffer ));