@@ -158,9 +158,17 @@ static struct ggml_tensor * llm_build_inp_embd(
158158}
159159
160160static struct ggml_tensor * llm_build_cross_embd (
161+ struct ggml_context * ctx,
161162 const llama_ubatch & ubatch
162163) {
163- struct ggml_tensor * cross_embd = ubatch.cross_embd ;
164+ struct ggml_tensor * cross_embd;
165+ if (ubatch.cross_embd ) {
166+ cross_embd = ubatch.cross_embd ;
167+ } else {
168+ printf (" ubatch does not have cross_embd tensor, "
169+ " building graph with placeholder instead\n " );
170+ cross_embd = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, 1024 , 6400 );
171+ }
164172 ggml_set_input (cross_embd);
165173 return cross_embd;
166174}
@@ -727,7 +735,7 @@ static struct ggml_tensor * llm_build_cross_kv(
727735 // H = number of heads
728736 // L = number of tokens
729737 // B = batch size
730- const int64_t num_heads = qcur-> ne [ 1 ] ;
738+ const int64_t num_heads = lctx. model . hparams . n_head () ;
731739 const float cross_attn_scale = 1 .0f / sqrtf (float (qcur->ne [0 ]));
732740 // Only add the computation of K and V if
733741 // the cache doesn't already have the data
@@ -744,10 +752,8 @@ static struct ggml_tensor * llm_build_cross_kv(
744752 // Compute cross attention score
745753 struct ggml_tensor * q = ggml_reshape_4d (ctx, qcur, qcur->ne [0 ] / num_heads,
746754 num_heads, qcur->ne [1 ], qcur->ne [2 ]);
747- k = ggml_reshape_4d (ctx, k, kcur->ne [0 ] / num_heads, num_heads,
748- kcur->ne [1 ], kcur->ne [2 ]);
749- v = ggml_reshape_4d (ctx, v, vcur->ne [0 ] / num_heads, num_heads,
750- vcur->ne [1 ], vcur->ne [2 ]);
755+ k = ggml_reshape_3d (ctx, k, 1024 / num_heads, num_heads, 6400 );
756+ v = ggml_reshape_3d (ctx, v, 1024 / num_heads, num_heads, 6400 );
751757 q = ggml_permute (ctx, q, 0 , 2 , 1 , 3 );
752758 k = ggml_permute (ctx, k, 0 , 2 , 1 , 3 );
753759 v = ggml_permute (ctx, v, 1 , 2 , 0 , 3 );
@@ -8194,7 +8200,7 @@ struct llm_build_context {
81948200
81958201 // Get the cross vision encoder embedded picture
81968202 struct ggml_tensor * cross_embd;
8197- cross_embd = llm_build_cross_embd (ubatch);
8203+ cross_embd = llm_build_cross_embd (ctx0, ubatch);
81988204
81998205 // Assuming text tokens are in ubatch.token, and image tokens are in ubatch.embd_tensor
82008206 bool batch_is_text;
@@ -8310,6 +8316,7 @@ struct llm_build_context {
83108316
83118317 inpSA = ggml_add (ctx0, inpSA, cur);
83128318 }
8319+ lctx.kv_cross .cache_filled = true ;
83138320
83148321 cur = ggml_rms_norm (ctx0, inpSA, hparams.f_norm_rms_eps );
83158322 cur = ggml_mul (ctx0, cur, model.output_norm );
0 commit comments