Skip to content

Commit 251c9c9

Browse files
committed
refactor
1 parent 0180f53 commit 251c9c9

29 files changed

+199
-194
lines changed

models/adept.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,12 @@ namespace fuyu
372372
void before_generate(const GenerationConfig &gen_config) override
373373
{
374374
std::vector<uint8_t> buf;
375-
auto &emb = dynamic_cast<ModelClass *>(transformer)->word_embeddings;
376-
visual.generate(gen_config, dynamic_cast<Tokenizer *>(tokenizer), ggml::type_of(emb.weight), buf);
375+
auto emb = dynamic_cast<Embedding *>(dynamic_cast<ModelClass *>(transformer)->word_embeddings);
376+
visual.generate(gen_config, dynamic_cast<Tokenizer *>(tokenizer), ggml::type_of(emb->weight), buf);
377377
if (buf.size() < 1) return;
378378

379-
size_t offset = emb.get_base_nbytes();
380-
Backend::write_tensor_data(emb.weight, buf.data(), offset, buf.size());
379+
size_t offset = emb->get_base_nbytes();
380+
Backend::write_tensor_data(emb->weight, buf.data(), offset, buf.size());
381381
}
382382

383383
public:

models/allenai.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ namespace moe
140140
void load(ModelLoader &loader) override
141141
{
142142
auto transformer = get_typed_transformer<ModelClass>();
143-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
143+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
144144
for (int i = 0; i < config.num_hidden_layers; i++)
145145
{
146146
std::string layer_prefix = "model.layers." + std::to_string(Base::layer_ids[i]) + '.';
@@ -165,7 +165,7 @@ namespace moe
165165
loader.read_tensor(layer_prefix + "self_attn.q_norm.weight", transformer->layers[i].attention.q_norm.weight);
166166
loader.read_tensor(layer_prefix + "self_attn.k_norm.weight", transformer->layers[i].attention.k_norm.weight);
167167
}
168-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
168+
transformer->final_layernorm->load("model.norm.", &loader);
169169
loader.read_tensor("lm_head.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
170170

171171
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
@@ -228,7 +228,7 @@ namespace dense
228228
void load(ModelLoader &loader) override
229229
{
230230
auto transformer = get_typed_transformer<ModelClass>();
231-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
231+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
232232
for (int i = 0; i < config.num_hidden_layers; i++)
233233
{
234234
std::string layer_prefix = "model.layers." + std::to_string(Base::layer_ids[i]) + '.';
@@ -251,7 +251,7 @@ namespace dense
251251
loader.read_tensor(layer_prefix + "self_attn.q_norm.weight", transformer->layers[i].attention.q_norm.weight);
252252
loader.read_tensor(layer_prefix + "self_attn.k_norm.weight", transformer->layers[i].attention.k_norm.weight);
253253
}
254-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
254+
transformer->final_layernorm->load("model.norm.", &loader);
255255
loader.read_tensor("lm_head.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
256256

257257
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())

models/alphageo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
558558
{
559559
auto transformer = get_typed_transformer<ModelClass>();
560560

561-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
561+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
562562

563563
for (int i = 0; i < config.num_hidden_layers; i++)
564564
{
@@ -578,7 +578,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
578578
loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", transformer->layers[i].attention.v_proj.weight);
579579
}
580580

581-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
581+
transformer->final_layernorm->load("model.norm.", &loader);
582582

583583
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
584584
<< "corrupted model weights";

models/baichuan.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ namespace m1
345345
class ConditionalGeneration : public BaseModelForConditionalGeneration
346346
{
347347
public:
348-
typedef HeterogeneousModel<BaseConfig, Embedding, RMSNorm> ModelClass;
348+
typedef HeterogeneousModel ModelClass;
349349

350350
public:
351351
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_BAICHUAN_M1)
@@ -377,7 +377,10 @@ namespace m1
377377
}
378378
};
379379

380-
transformer = new ModelClass(&w_ctx_, config, false, create_layer);
380+
transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
381+
create_embedding<Embedding>(&w_ctx_, config),
382+
create_final_norm<RMSNorm>(&w_ctx_, config),
383+
create_lm_head(&w_ctx_, config, false), create_layer);
381384

382385
for (int i = 0; i < config.num_hidden_layers; i++)
383386
{

models/chatglm.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ namespace v1
187187
{
188188
TransformerClass *transformer = dynamic_cast<TransformerClass *>(this->transformer);
189189

190-
loader.read_tensor("transformer.word_embeddings.weight", transformer->word_embeddings.weight);
190+
transformer->word_embeddings->load("transformer.word_embeddings.", &loader);
191191
for (int i = 0; i < config.num_hidden_layers; i++)
192192
{
193193
std::string layer_prefix = "transformer.layers." + std::to_string(layer_ids[i]) + '.';
@@ -208,8 +208,7 @@ namespace v1
208208
loader.read_tensor(layer_prefix + "mlp.dense_4h_to_h.weight", transformer->layers[i].mlp.fc1.weight);
209209
loader.read_tensor(layer_prefix + "mlp.dense_4h_to_h.bias", transformer->layers[i].mlp.fc1.bias);
210210
}
211-
loader.read_tensor("transformer.final_layernorm.weight", transformer->final_layernorm.weight);
212-
loader.read_tensor("transformer.final_layernorm.bias", transformer->final_layernorm.bias);
211+
transformer->final_layernorm->load("transformer.final_layernorm.", &loader);
213212

214213
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
215214
<< "corrupted model weights";
@@ -339,7 +338,7 @@ namespace v2
339338
void ConditionalGeneration::load(ModelLoader &loader)
340339
{
341340
TransformerClass *transformer = dynamic_cast<TransformerClass *>(this->transformer);
342-
loader.read_tensor("transformer.embedding.word_embeddings.weight", transformer->word_embeddings.weight);
341+
transformer->word_embeddings->load("transformer.embedding.word_embeddings.", &loader);
343342
for (int i = 0; i < config.num_hidden_layers; i++)
344343
{
345344
std::string layer_prefix = "transformer.encoder.layers." + std::to_string(layer_ids[i]) + '.';
@@ -354,7 +353,7 @@ namespace v2
354353
loader.read_tensor(layer_prefix + "mlp.dense_h_to_4h.weight", transformer->layers[i].mlp.dense_h_to_4h.weight);
355354
loader.read_tensor(layer_prefix + "mlp.dense_4h_to_h.weight", transformer->layers[i].mlp.dense_4h_to_h.weight);
356355
}
357-
loader.read_tensor("transformer.encoder.final_layernorm.weight", transformer->final_layernorm.weight);
356+
transformer->final_layernorm->load("transformer.encoder.final_layernorm.", &loader);
358357
loader.read_tensor("transformer.output_layer.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
359358

360359
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())

models/cohere.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
108108
{
109109
auto transformer = get_typed_transformer<ModelClass>();
110110

111-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
111+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
112112
for (int i = 0; i < config.num_hidden_layers; i++)
113113
{
114114
std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
@@ -122,7 +122,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
122122
loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", transformer->layers[i].attention.q_proj.weight);
123123
loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", transformer->layers[i].attention.v_proj.weight);
124124
}
125-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
125+
transformer->final_layernorm->load("model.norm.", &loader);
126126

127127
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
128128
<< "corrupted model weights";
@@ -234,7 +234,7 @@ namespace v2
234234
class ConditionalGeneration : public BaseModelForConditionalGeneration
235235
{
236236
public:
237-
typedef HeterogeneousModel<BaseConfig, Embedding, LayerNormNoBias> ModelClass;
237+
typedef HeterogeneousModel ModelClass;
238238
typedef Cohere2SWABlock<SLIDING_WINDOW_LEN> Cohere2SWABlock4k;
239239

240240
public:
@@ -266,7 +266,10 @@ namespace v2
266266
}
267267
};
268268

269-
transformer = new ModelClass(&w_ctx_, config, nullptr, create_layer);
269+
transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
270+
create_embedding<Embedding>(&w_ctx_, config),
271+
create_final_norm<LayerNormNoBias>(&w_ctx_, config),
272+
nullptr, create_layer);
270273

271274
for (int i = 0; i < config.num_hidden_layers; i++)
272275
{
@@ -295,7 +298,7 @@ namespace v2
295298
loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", layer->attention.q_proj.weight); \
296299
loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", layer->attention.v_proj.weight);
297300

298-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
301+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
299302
for (int i = 0; i < config.num_hidden_layers; i++)
300303
{
301304
std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
@@ -310,7 +313,7 @@ namespace v2
310313
LOAD_TENSORS();
311314
}
312315
}
313-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
316+
transformer->final_layernorm->load("model.norm.", &loader);
314317

315318
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
316319
<< "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " != " << w_ctx_.get_mem_size() / ggml_tensor_overhead();

models/decilm.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
1919
{
2020
public:
2121
typedef BaseModelForConditionalGeneration Base;
22-
typedef HeterogeneousModel<Config, Embedding, RMSNorm> ModelClass;
22+
typedef HeterogeneousModel ModelClass;
2323

2424
public:
2525
ConditionalGeneration() = default;
@@ -63,8 +63,10 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
6363
}
6464
};
6565

66-
transformer = new ModelClass(
67-
&w_ctx_, config, false, create_layer);
66+
transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
67+
create_embedding<Embedding>(&w_ctx_, config),
68+
create_final_norm<RMSNorm>(&w_ctx_, config),
69+
create_lm_head(&w_ctx_, config, false), create_layer);
6870

6971
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
7072
<< "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " vs "
@@ -74,8 +76,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
7476
void load(ModelLoader &loader) override
7577
{
7678
auto transformer = get_typed_transformer<ModelClass>();
77-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
78-
79+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
7980

8081
#define LOAD_MLP() \
8182
loader.read_tensor(layer_prefix + "mlp.down_proj.weight", layer->mlp.down_proj.weight); \
@@ -107,7 +108,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
107108
}
108109

109110
}
110-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
111+
transformer->final_layernorm->load("model.norm.", &loader);
111112

112113
#undef LOAD_MLP
113114

models/deepseek.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ namespace v1_moe
139139
typedef CombinedMLP<DeepSeekSparseMoE<NUM_EXPERTS, EFFECTIVE_EXPERTS_PER_TOK>, SiLUMLP> DeepSeekMoEMLP;
140140
typedef LMBlock1<RMSNorm, LlamaSelfAttention, RMSNorm, DeepSeekMoEMLP> DeepSeekMoEBlock;
141141
typedef BaseModelForConditionalGeneration Base;
142-
typedef HeterogeneousModel<Config, Embedding, RMSNorm> ModelClass;
142+
typedef HeterogeneousModel ModelClass;
143143
public:
144144
ConditionalGeneration0() = default;
145145

@@ -184,7 +184,11 @@ namespace v1_moe
184184
}
185185
};
186186

187-
auto transformer = new ModelClass(&w_ctx_, config, false, create_layer);
187+
auto transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
188+
create_embedding<Embedding>(&w_ctx_, config),
189+
create_final_norm<RMSNorm>(&w_ctx_, config),
190+
create_lm_head(&w_ctx_, config, false), create_layer);
191+
188192
Base::transformer = transformer;
189193

190194
#define config_rope(attention) do { \
@@ -683,7 +687,7 @@ namespace v2_light
683687
typedef CombinedMLP<DeepSeekSparseMoE<NUM_EXPERTS, EFFECTIVE_EXPERTS_PER_TOK>, SiLUMLP> DeepSeekMoEMLP;
684688
typedef LMBlock1<RMSNorm, SpeedMLAttention, RMSNorm, DeepSeekMoEMLP> DeepSeek2MoEBlock;
685689
typedef BaseModelForConditionalGeneration Base;
686-
typedef HeterogeneousModel<Config, Embedding, RMSNorm> ModelClass;
690+
typedef HeterogeneousModel ModelClass;
687691
public:
688692
ConditionalGeneration0() = default;
689693

@@ -735,8 +739,10 @@ namespace v2_light
735739
}
736740
};
737741

738-
auto transformer = new ModelClass(
739-
&w_ctx_, config, false, create_layer);
742+
auto transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
743+
create_embedding<Embedding>(&w_ctx_, config),
744+
create_final_norm<RMSNorm>(&w_ctx_, config),
745+
create_lm_head(&w_ctx_, config, false), create_layer);
740746
Base::transformer = transformer;
741747

742748
float m = 1.0f;

models/gemma.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
105105
{
106106
auto transformer = get_typed_transformer<ModelClass>();
107107

108-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
108+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
109109
for (int i = 0; i < config.num_hidden_layers; i++)
110110
{
111111
std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
@@ -120,7 +120,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
120120
loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", transformer->layers[i].attention.q_proj.weight);
121121
loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", transformer->layers[i].attention.v_proj.weight);
122122
}
123-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
123+
transformer->final_layernorm->load("model.norm.", &loader);
124124

125125
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
126126
<< "corrupted model weights";
@@ -257,7 +257,7 @@ template <class Layer> static void load_layer(ModelLoader &loader, const std::st
257257
class ConditionalGeneration : public BaseModelForConditionalGeneration
258258
{
259259
public:
260-
typedef HeterogeneousModel<BaseConfig, Embedding, RMSNorm> ModelClass;
260+
typedef HeterogeneousModel ModelClass;
261261
public:
262262
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GEMMA2)
263263
: BaseModelForConditionalGeneration(type, config, runtime_config), config(config),
@@ -286,7 +286,11 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
286286
}
287287
};
288288

289-
transformer = new ModelClass(&w_ctx_, config, nullptr, create_layer);
289+
transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
290+
create_embedding<Embedding>(&w_ctx_, config),
291+
create_final_norm<RMSNorm>(&w_ctx_, config),
292+
nullptr,
293+
create_layer);
290294

291295
get_typed_transformer<ModelClass>()->logits_pp = &logits_pp;
292296

@@ -309,7 +313,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
309313
{
310314
auto transformer = get_typed_transformer<ModelClass>();
311315

312-
loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
316+
transformer->word_embeddings->load("model.embed_tokens.", &loader);
313317
for (int i = 0; i < config.num_hidden_layers; i++)
314318
{
315319
std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
@@ -322,7 +326,7 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
322326
load_layer<Gemma2FullBlock>(loader, layer_prefix, transformer->get_layer(i));
323327
}
324328
}
325-
loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
329+
transformer->final_layernorm->load("model.norm.", &loader);
326330

327331
CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
328332
<< "corrupted model weights";
@@ -763,7 +767,7 @@ template <class Layer> static void setup_layer(Block *block, const Config &confi
763767
class ConditionalGeneration : public BaseModelForConditionalGeneration
764768
{
765769
public:
766-
typedef HeterogeneousModel<BaseConfig, Embedding, RMSNorm> ModelClass;
770+
typedef HeterogeneousModel ModelClass;
767771
public:
768772
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GEMMA3)
769773
: BaseModelForConditionalGeneration(type, config, runtime_config, 4096 * 2), config(config),
@@ -804,7 +808,10 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
804808
BlockParams::PadEmbedding padding(1024, 1024); // 4 media_emb
805809
_chat_encoder.MAX_PATCH_NUM = padding.get();
806810

807-
transformer = new ModelClass(&w_ctx_, config, nullptr, create_layer);
811+
transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
812+
create_embedding<Embedding>(&w_ctx_, config),
813+
create_final_norm<RMSNorm>(&w_ctx_, config),
814+
nullptr, create_layer);
808815

809816
for (int i = 0; i < config.num_hidden_layers; i++)
810817
{
@@ -856,12 +863,12 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
856863
void before_generate(const GenerationConfig &gen_config) override
857864
{
858865
std::vector<uint8_t> buf;
859-
auto &emb = dynamic_cast<ModelClass *>(transformer)->word_embeddings;
860-
visual.generate(gen_config, dynamic_cast<Tokenizer *>(tokenizer), ggml::type_of(emb.weight), buf);
866+
auto emb = dynamic_cast<Embedding *>(dynamic_cast<ModelClass *>(transformer)->word_embeddings);
867+
visual.generate(gen_config, dynamic_cast<Tokenizer *>(tokenizer), ggml::type_of(emb->weight), buf);
861868
if (buf.size() < 1) return;
862869

863-
size_t offset = emb.get_base_nbytes();
864-
Backend::write_tensor_data(emb.weight, buf.data(), offset, buf.size());
870+
size_t offset = emb->get_base_nbytes();
871+
Backend::write_tensor_data(emb->weight, buf.data(), offset, buf.size());
865872
}
866873

867874
public:

0 commit comments

Comments
 (0)