 #include <functional>
 #include <map>
 #include <regex>
+#include <sstream>
 #include <stdexcept>
 
 const char * llm_type_name(llm_type type) {
@@ -8008,8 +8009,11 @@ struct llm_build_plamo2 : public llm_graph_context {
         // const int64_t n_embd_head = hparams.n_embd_head_v;
         // ggml_tensor * inp_pos = build_inp_pos();
 
+        // TODO: Cast to f32 is currently required for ggml_get_rows in build_inp_embd
+        ggml_tensor * embed_tokens = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32);
+
         // {n_embd, n_tokens}
-        ggml_tensor * inpL = build_inp_embd(model.tok_embd);
+        ggml_tensor * inpL = build_inp_embd(embed_tokens);
         cb(inpL, "embedding_output", -1);
 
         // ensure the memory context is hybrid
@@ -8023,9 +8027,8 @@ struct llm_build_plamo2 : public llm_graph_context {
             // cb(model.layers[il].attn_norm, "attn_norm", il);
 
             // pre_mixer_norm
-            // cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
-            cur = ggml_rms_norm(ctx0, inpL, hparams.f_norm_rms_eps);
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
+            cb(inpL, "attn_pre_norm_input", il);
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "attn_pre_norm", il);
 
             // check if this layer is Mamba or Attention
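
For reference on the second hunk: judging from the lines it removes, the LLM_NORM_RMS path of build_norm covers the same normalize-then-scale pattern that was previously spelled out with ggml_rms_norm and ggml_mul. The sketch below only restates that pattern; the wrapper function is hypothetical and not part of the commit.

    #include "ggml.h"

    // Minimal sketch of the RMS-norm pattern replaced above by
    // build_norm(inpL, w, NULL, LLM_NORM_RMS, il). The function name is
    // illustrative; ctx0, x, w and eps come from the calling graph builder.
    static ggml_tensor * rms_norm_then_scale(ggml_context * ctx0,
                                             ggml_tensor  * x,
                                             ggml_tensor  * w,
                                             float          eps) {
        ggml_tensor * cur = ggml_rms_norm(ctx0, x, eps); // normalize over n_embd
        cur = ggml_mul(ctx0, cur, w);                    // scale by the norm weight
        return cur;
    }

In the PLaMo-2 builder this corresponds to passing hparams.f_norm_rms_eps and model.layers[il].attn_norm, as in the removed lines.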