@@ -105,7 +105,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
105105 {
106106 auto transformer = get_typed_transformer<ModelClass>();
107107
108- loader. read_tensor (" model.embed_tokens.weight " , transformer-> word_embeddings . weight );
108+ transformer-> word_embeddings -> load (" model.embed_tokens." , &loader );
109109 for (int i = 0 ; i < config.num_hidden_layers ; i++)
110110 {
111111 std::string layer_prefix = " model.layers." + std::to_string (layer_ids[i]) + ' .' ;
@@ -120,7 +120,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
120120 loader.read_tensor (layer_prefix + " self_attn.q_proj.weight" , transformer->layers [i].attention .q_proj .weight );
121121 loader.read_tensor (layer_prefix + " self_attn.v_proj.weight" , transformer->layers [i].attention .v_proj .weight );
122122 }
123- loader. read_tensor (" model.norm.weight " , transformer-> final_layernorm . weight );
123+ transformer-> final_layernorm -> load (" model.norm." , &loader );
124124
125125 CHATLLM_CHECK (w_ctx_.get_used_mem () == w_ctx_.get_mem_size ())
126126 << " corrupted model weights" ;
@@ -257,7 +257,7 @@ template <class Layer> static void load_layer(ModelLoader &loader, const std::st
257257class ConditionalGeneration : public BaseModelForConditionalGeneration
258258{
259259public:
260- typedef HeterogeneousModel<BaseConfig, Embedding, RMSNorm> ModelClass;
260+ typedef HeterogeneousModel ModelClass;
261261public:
262262 ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GEMMA2)
263263 : BaseModelForConditionalGeneration(type, config, runtime_config), config(config),
@@ -286,7 +286,11 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
286286 }
287287 };
288288
289- transformer = new ModelClass (&w_ctx_, config, nullptr , create_layer);
289+ transformer = new ModelClass (&w_ctx_, config.num_hidden_layers , config.hidden_size ,
290+ create_embedding<Embedding>(&w_ctx_, config),
291+ create_final_norm<RMSNorm>(&w_ctx_, config),
292+ nullptr ,
293+ create_layer);
290294
291295 get_typed_transformer<ModelClass>()->logits_pp = &logits_pp;
292296
@@ -309,7 +313,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
309313 {
310314 auto transformer = get_typed_transformer<ModelClass>();
311315
312- loader. read_tensor (" model.embed_tokens.weight " , transformer-> word_embeddings . weight );
316+ transformer-> word_embeddings -> load (" model.embed_tokens." , &loader );
313317 for (int i = 0 ; i < config.num_hidden_layers ; i++)
314318 {
315319 std::string layer_prefix = " model.layers." + std::to_string (layer_ids[i]) + ' .' ;
@@ -322,7 +326,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
322326 load_layer<Gemma2FullBlock>(loader, layer_prefix, transformer->get_layer (i));
323327 }
324328 }
325- loader. read_tensor (" model.norm.weight " , transformer-> final_layernorm . weight );
329+ transformer-> final_layernorm -> load (" model.norm." , &loader );
326330
327331 CHATLLM_CHECK (w_ctx_.get_used_mem () == w_ctx_.get_mem_size ())
328332 << " corrupted model weights" ;
@@ -763,7 +767,7 @@ template <class Layer> static void setup_layer(Block *block, const Config &confi
763767class ConditionalGeneration : public BaseModelForConditionalGeneration
764768{
765769public:
766- typedef HeterogeneousModel<BaseConfig, Embedding, RMSNorm> ModelClass;
770+ typedef HeterogeneousModel ModelClass;
767771public:
768772 ConditionalGeneration (const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GEMMA3)
769773 : BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 2 ), config(config),
@@ -804,7 +808,10 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
804808 BlockParams::PadEmbedding padding (1024 , 1024 ); // 4 media_emb
805809 _chat_encoder.MAX_PATCH_NUM = padding.get ();
806810
807- transformer = new ModelClass (&w_ctx_, config, nullptr , create_layer);
811+ transformer = new ModelClass (&w_ctx_, config.num_hidden_layers , config.hidden_size ,
812+ create_embedding<Embedding>(&w_ctx_, config),
813+ create_final_norm<RMSNorm>(&w_ctx_, config),
814+ nullptr , create_layer);
808815
809816 for (int i = 0 ; i < config.num_hidden_layers ; i++)
810817 {
@@ -856,12 +863,12 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
// before_generate: override invoked with the generation config. It asks `visual` to
// generate data into a byte buffer and, when the buffer is non-empty, writes those
// bytes into the word-embedding weight tensor at byte offset get_base_nbytes().
// The '-' lines below are the pre-change form (word_embeddings accessed as an object
// via '.'); the '+' lines are the post-change form (word_embeddings is a pointer that
// is downcast to Embedding * and accessed via '->').
856863 void before_generate (const GenerationConfig &gen_config) override
857864 {
858865 std::vector<uint8_t > buf;
859- auto & emb = dynamic_cast <ModelClass *>(transformer)->word_embeddings ;
860- visual.generate (gen_config, dynamic_cast <Tokenizer *>(tokenizer), ggml::type_of (emb. weight ), buf);
// NOTE(review): neither dynamic_cast result is null-checked before being dereferenced
// in the lines shown here — confirm the casts cannot fail for this model class.
866+ auto emb = dynamic_cast <Embedding *>( dynamic_cast < ModelClass *>(transformer)->word_embeddings ) ;
// The element type of the embedding weight is passed to visual.generate, which fills `buf`.
867+ visual.generate (gen_config, dynamic_cast <Tokenizer *>(tokenizer), ggml::type_of (emb-> weight ), buf);
// Nothing was produced: leave the embedding tensor untouched.
861868 if (buf.size () < 1 ) return ;
862869
863- size_t offset = emb. get_base_nbytes ();
864- Backend::write_tensor_data (emb. weight , buf.data (), offset, buf.size ());
// Write the generated bytes into the weight tensor starting at get_base_nbytes()
// (presumably the end of the regular vocabulary rows, i.e. the padded region — TODO confirm).
870+ size_t offset = emb-> get_base_nbytes ();
871+ Backend::write_tensor_data (emb-> weight , buf.data (), offset, buf.size ());
865872 }
866873
867874public:
0 commit comments