@@ -345,8 +345,18 @@ struct lora_merge_ctx {
         gf = ggml_new_graph(ctx0);
         struct ggml_tensor * cur = inp_base;
         for (size_t i = 0; i < adapters.size(); ++i) {
-            struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
-            struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            struct ggml_tensor * delta;
+            bool is_tok_embd = string_starts_with(name_base, "token_embd");
+            if (is_tok_embd) {
+                printf("%s : detected token embeddings tensor\n", __func__);
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
+                    ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
+            } else {
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            }
             // scale
             const float alpha = adapters[i]->alpha;
             const float rank = (float) inp_b[i]->ne[0];
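For context on why the two branches order the operands differently (a sketch of the shape reasoning, not part of the change itself): `ggml_mul_mat(ctx, a, b)` requires `a->ne[0] == b->ne[0]` and produces a result of shape `[a->ne[1], b->ne[1]]`. For a regular weight, `lora_a` arrives as `[n_in, rank]`, so the else-branch transposes it (and makes it contiguous) to line its inner dimension up with `lora_b`'s `[rank, n_out]`, giving a `[n_in, n_out]` delta. The new `token_embd` branch instead multiplies `lora_b` by `lora_a` directly, which works if both adapter factors are already rank-first so the product lands in the embedding tensor's own layout. The standalone sketch below only mirrors that shape rule; `mul_mat_shape`, the dimension values, and the `[rank, n_embd]` / `[rank, n_vocab]` layouts are illustrative assumptions, not taken from the source:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative helper (not a ggml API): mimics ggml_mul_mat's shape rule,
// which requires a.ne0 == b.ne0 and yields a result of shape { a.ne1, b.ne1 }.
struct shape2d { int64_t ne0, ne1; };

static bool mul_mat_shape(const shape2d & a, const shape2d & b, shape2d & out) {
    if (a.ne0 != b.ne0) {
        return false; // inner dimensions do not line up
    }
    out = { a.ne1, b.ne1 };
    return true;
}

int main() {
    // hypothetical sizes, chosen only to make the shapes concrete
    const int64_t n_in = 4096, n_out = 11008, n_embd = 4096, n_vocab = 32000, rank = 16;

    shape2d delta;

    // regular weight: lora_a is [n_in, rank]; transposing it (as the else-branch does)
    // gives [rank, n_in], which matches lora_b's [rank, n_out] on ne[0]
    shape2d a_T = { rank, n_in  };
    shape2d b   = { rank, n_out };
    if (mul_mat_shape(a_T, b, delta)) {
        printf("regular delta:    [%lld, %lld]\n", (long long) delta.ne0, (long long) delta.ne1);
    }

    // token_embd: assumed adapter layout with both factors rank-first, so no transpose
    shape2d emb_b = { rank, n_embd  };
    shape2d emb_a = { rank, n_vocab };
    if (mul_mat_shape(emb_b, emb_a, delta)) {
        printf("token_embd delta: [%lld, %lld]\n", (long long) delta.ne0, (long long) delta.ne1);
    }
    return 0;
}
```

Note that `rank = inp_b[i]->ne[0]` in the scaling code below the hunk stays valid in both branches, since `lora_b` keeps rank as its first dimension either way.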