2222#include <functional>
2323#include <map>
2424#include <regex>
25+ #include <sstream>
2526#include <stdexcept>
2627
2728const char * llm_type_name(llm_type type) {
@@ -8215,8 +8216,11 @@ struct llm_build_plamo2 : public llm_graph_context {
82158216 // const int64_t n_embd_head = hparams.n_embd_head_v;
82168217 // ggml_tensor * inp_pos = build_inp_pos();
82178218
8219+ // TODO: Cast to f32 is currently required for ggml_get_rows in build_inp_embd
8220+ ggml_tensor * embed_tokens = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32);
8221+
82188222 // {n_embd, n_tokens}
8219- ggml_tensor * inpL = build_inp_embd(model.tok_embd );
8223+ ggml_tensor * inpL = build_inp_embd(embed_tokens );
82208224 cb(inpL, "embedding_output", -1);
82218225
82228226 // ensure the memory context is hybrid
@@ -8230,9 +8234,8 @@ struct llm_build_plamo2 : public llm_graph_context {
82308234 // cb(model.layers[il].attn_norm, "attn_norm", il);
82318235
82328236 // pre_mixer_norm
8233- // cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
8234- cur = ggml_rms_norm(ctx0, inpL, hparams.f_norm_rms_eps);
8235- cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
8237+ cb(inpL, "attn_pre_norm_input", il);
8238+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
82368239 cb(cur, "attn_pre_norm", il);
82378240
82388241 // check if this layer is Mamba or Attention
0 commit comments