Skip to content

Commit bc12eb7

Browse files
committed
refactor: update causal LM implementations to inherit from LlmForCausalLMImplBase.
1 parent f2baa8d commit bc12eb7

File tree

6 files changed

+34
-325
lines changed

6 files changed

+34
-325
lines changed

xllm/models/llm/npu/deepseek_v2.h

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "core/layers/npu/npu_rms_norm_impl.h"
3636
#include "core/layers/npu/npu_word_embedding_impl.h"
3737
#include "core/layers/npu/rotary_embedding.h"
38+
#include "llm_model_base.h"
3839
#include "models/model_registry.h"
3940
// DeepSeek v2 compatible with huggingface weights
4041
// ref to:
@@ -265,72 +266,25 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
265266
};
266267
TORCH_MODULE(DeepseekV2Model);
267268

268-
class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
269+
class DeepseekV2ForCausalLMImpl
270+
: public LlmForCausalLMImplBase<DeepseekV2Model> {
269271
public:
270-
DeepseekV2ForCausalLMImpl(const ModelContext& context) {
271-
model_ = register_module("model", DeepseekV2Model(context));
272-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
273-
first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
274-
}
275-
276-
// tokens: [num_tokens]
277-
// positions: [num_tokens] token pos in the sequence
278-
// returns: [num_tokens, hidden_size]
279-
torch::Tensor forward(const torch::Tensor& tokens,
280-
const torch::Tensor& positions,
281-
std::vector<KVCache>& kv_caches,
282-
const ModelInputParams& input_params) {
283-
return model_(tokens, positions, kv_caches, input_params);
284-
}
285-
286-
// hidden_states: [num_tokens, hidden_size]
287-
// seleted_idxes: [num_tokens]
288-
// returns: [num_tokens, vocab_size]
289-
torch::Tensor logits(const torch::Tensor& hidden_states,
290-
const torch::Tensor& seleted_idxes) {
291-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
292-
}
293-
294-
void load_model(std::unique_ptr<ModelLoader> loader) {
295-
for (const auto& state_dict : loader->get_state_dicts()) {
296-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
297-
npu_lm_head_->load_state_dict(
298-
state_dict->get_dict_with_prefix("lm_head."));
299-
}
300-
301-
// verify
302-
model_->verify_loaded_weights("model.");
303-
npu_lm_head_->verify_loaded_weights("lm_head.");
304-
305-
model_->merge_loaded_weights();
306-
npu_lm_head_->merge_loaded_weights();
307-
}
272+
DeepseekV2ForCausalLMImpl(const ModelContext& context)
273+
: LlmForCausalLMImplBase<DeepseekV2Model>(context),
274+
first_k_dense_replace_(
275+
context.get_model_args().first_k_dense_replace()) {}
308276

309277
void prepare_expert_weight(int32_t layer_id,
310-
const std::vector<int32_t>& expert_ids) {
278+
const std::vector<int32_t>& expert_ids) override {
311279
model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
312280
expert_ids);
313281
}
314282

315-
void update_expert_weight(int32_t layer_id) {
283+
void update_expert_weight(int32_t layer_id) override {
316284
model_->update_expert_weight(layer_id + first_k_dense_replace_);
317285
}
318286

319-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
320-
321-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
322-
323-
layer::NpuWordEmbedding get_npu_word_embedding() {
324-
return model_->get_npu_word_embedding();
325-
}
326-
327-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
328-
model_->set_npu_word_embedding(npu_word_embedding);
329-
}
330-
331287
private:
332-
DeepseekV2Model model_{nullptr};
333-
layer::NpuLmHead npu_lm_head_{nullptr};
334288
int32_t first_k_dense_replace_;
335289
};
336290
TORCH_MODULE(DeepseekV2ForCausalLM);

xllm/models/llm/npu/deepseek_v32.h

Lines changed: 10 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,9 @@ limitations under the License.
1515

1616
#pragma once
1717

18-
#include <torch/torch.h>
19-
20-
#include <string>
21-
#include <vector>
22-
23-
#include "core/common/global_flags.h"
24-
#include "core/framework/kv_cache/kv_cache.h"
25-
#include "core/framework/model/model_input_params.h"
26-
#include "core/framework/model/npu_dp_ep_padding.h"
27-
#include "core/framework/model_context.h"
28-
#include "core/layers/common/attention_mask.h"
29-
#include "core/layers/common/rotary_embedding_util.h"
3018
#include "core/layers/npu/npu_deepseek_v32_decoder_layer_impl.h"
31-
#include "core/layers/npu/npu_lm_head_impl.h"
32-
#include "core/layers/npu/npu_pos_embedding_impl.h"
33-
#include "core/layers/npu/npu_rms_norm_impl.h"
34-
#include "core/layers/npu/npu_word_embedding_impl.h"
35-
#include "models/model_registry.h"
19+
#include "llm_model_base.h"
20+
3621
// DeepSeek v32 compatible with huggingface weights
3722
// ref to:
3823
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py
@@ -263,72 +248,25 @@ class DeepseekV32ModelImpl : public torch::nn::Module {
263248
};
264249
TORCH_MODULE(DeepseekV32Model);
265250

266-
class DeepseekV32ForCausalLMImpl : public torch::nn::Module {
251+
class DeepseekV32ForCausalLMImpl
252+
: public LlmForCausalLMImplBase<DeepseekV32Model> {
267253
public:
268-
DeepseekV32ForCausalLMImpl(const ModelContext& context) {
269-
model_ = register_module("model", DeepseekV32Model(context));
270-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
271-
first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
272-
}
273-
274-
// tokens: [num_tokens]
275-
// positions: [num_tokens] token pos in the sequence
276-
// returns: [num_tokens, hidden_size]
277-
torch::Tensor forward(const torch::Tensor& tokens,
278-
const torch::Tensor& positions,
279-
std::vector<KVCache>& kv_caches,
280-
const ModelInputParams& input_params) {
281-
return model_(tokens, positions, kv_caches, input_params);
282-
}
283-
284-
// hidden_states: [num_tokens, hidden_size]
285-
// seleted_idxes: [num_tokens]
286-
// returns: [num_tokens, vocab_size]
287-
torch::Tensor logits(const torch::Tensor& hidden_states,
288-
const torch::Tensor& seleted_idxes) {
289-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
290-
}
291-
292-
void load_model(std::unique_ptr<ModelLoader> loader) {
293-
for (const auto& state_dict : loader->get_state_dicts()) {
294-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
295-
npu_lm_head_->load_state_dict(
296-
state_dict->get_dict_with_prefix("lm_head."));
297-
}
298-
299-
// verify
300-
model_->verify_loaded_weights("model.");
301-
npu_lm_head_->verify_loaded_weights("lm_head.");
302-
303-
model_->merge_loaded_weights();
304-
npu_lm_head_->merge_loaded_weights();
305-
}
254+
DeepseekV32ForCausalLMImpl(const ModelContext& context)
255+
: LlmForCausalLMImplBase<DeepseekV32Model>(context),
256+
first_k_dense_replace_(
257+
context.get_model_args().first_k_dense_replace()) {}
306258

307259
void prepare_expert_weight(int32_t layer_id,
308-
const std::vector<int32_t>& expert_ids) {
260+
const std::vector<int32_t>& expert_ids) override {
309261
model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
310262
expert_ids);
311263
}
312264

313-
void update_expert_weight(int32_t layer_id) {
265+
void update_expert_weight(int32_t layer_id) override {
314266
model_->update_expert_weight(layer_id + first_k_dense_replace_);
315267
}
316268

317-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
318-
319-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
320-
321-
layer::NpuWordEmbedding get_npu_word_embedding() {
322-
return model_->get_npu_word_embedding();
323-
}
324-
325-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
326-
model_->set_npu_word_embedding(npu_word_embedding);
327-
}
328-
329269
private:
330-
DeepseekV32Model model_{nullptr};
331-
layer::NpuLmHead npu_lm_head_{nullptr};
332270
int32_t first_k_dense_replace_;
333271
};
334272
TORCH_MODULE(DeepseekV32ForCausalLM);

xllm/models/llm/npu/glm4_moe.h

Lines changed: 3 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -286,74 +286,10 @@ class Glm4MoeModelImpl : public torch::nn::Module {
286286
};
287287
TORCH_MODULE(Glm4MoeModel);
288288

289-
class Glm4MoeForCausalLMImpl : public torch::nn::Module {
289+
class Glm4MoeForCausalLMImpl : public LlmForCausalLMImplBase<Glm4MoeModel> {
290290
public:
291-
Glm4MoeForCausalLMImpl(const ModelContext& context) {
292-
model_ = register_module("model", Glm4MoeModel(context));
293-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
294-
}
295-
296-
torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
297-
return model_->get_input_embeddings(input_ids);
298-
}
299-
300-
// tokens: [num_tokens]
301-
// positions: [num_tokens] token pos in the sequence
302-
// returns: [num_tokens, hidden_size]
303-
torch::Tensor forward(const torch::Tensor& tokens,
304-
const torch::Tensor& positions,
305-
std::vector<KVCache>& kv_caches,
306-
const ModelInputParams& input_params) {
307-
return model_(tokens, positions, kv_caches, input_params);
308-
}
309-
310-
// hidden_states: [num_tokens, hidden_size]
311-
// seleted_idxes: [num_tokens]
312-
// returns: [num_tokens, vocab_size]
313-
torch::Tensor logits(const torch::Tensor& hidden_states,
314-
const torch::Tensor& seleted_idxes) {
315-
// select tokens if provided
316-
auto h = hidden_states;
317-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
318-
}
319-
320-
void load_model(std::unique_ptr<ModelLoader> loader,
321-
std::string prefix = "model." /*llm model weight prefix*/) {
322-
for (const auto& state_dict : loader->get_state_dicts()) {
323-
model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
324-
npu_lm_head_->load_state_dict(
325-
state_dict->get_dict_with_prefix("lm_head."));
326-
}
327-
328-
// verify
329-
model_->verify_loaded_weights(prefix);
330-
npu_lm_head_->verify_loaded_weights("lm_head.");
331-
332-
model_->merge_loaded_weights();
333-
npu_lm_head_->merge_loaded_weights();
334-
}
335-
336-
virtual void prepare_expert_weight(int32_t layer_id,
337-
const std::vector<int32_t>& expert_ids) {
338-
return;
339-
}
340-
virtual void update_expert_weight(int32_t layer_id) { return; }
341-
342-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
343-
344-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
345-
346-
layer::NpuWordEmbedding get_npu_word_embedding() {
347-
return model_->get_npu_word_embedding();
348-
}
349-
350-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
351-
model_->set_npu_word_embedding(npu_word_embedding);
352-
}
353-
354-
private:
355-
Glm4MoeModel model_{nullptr};
356-
layer::NpuLmHead npu_lm_head_{nullptr};
291+
Glm4MoeForCausalLMImpl(const ModelContext& context)
292+
: LlmForCausalLMImplBase<Glm4MoeModel>(context) {}
357293
};
358294
TORCH_MODULE(Glm4MoeForCausalLM);
359295

xllm/models/llm/npu/llama.h

Lines changed: 4 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "core/layers/npu/npu_llama_decoder_layer_impl.h"
3131
#include "core/layers/npu/npu_rms_norm_impl.h"
3232
#include "core/util/tensor_helper.h"
33+
#include "llm_model_base.h"
3334
#include "models/model_registry.h"
3435
#include "xllm_kernels/core/include/atb_speed/log.h"
3536

@@ -234,72 +235,10 @@ class LlamaModelImpl : public torch::nn::Module {
234235
};
235236
TORCH_MODULE(LlamaModel);
236237

237-
class LlamaForCausalLMImpl : public torch::nn::Module {
238+
class LlamaForCausalLMImpl : public LlmForCausalLMImplBase<LlamaModel> {
238239
public:
239-
LlamaForCausalLMImpl(const ModelContext& context) {
240-
auto options = context.get_tensor_options();
241-
242-
// register submodules
243-
model_ = register_module("model", LlamaModel(context));
244-
device_id_ = options.device().index();
245-
npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
246-
}
247-
// tokens: [num_tokens]
248-
// positions: [num_tokens] token pos in the sequence
249-
// returns: [num_tokens, hidden_size]
250-
torch::Tensor forward(const torch::Tensor& tokens,
251-
const torch::Tensor& positions,
252-
std::vector<KVCache>& kv_caches,
253-
const ModelInputParams& input_params) {
254-
return model_(tokens, positions, kv_caches, input_params);
255-
}
256-
257-
// hidden_states: [num_tokens, hidden_size]
258-
// seleted_idxes: [num_tokens]
259-
// returns: [num_tokens, vocab_size]
260-
torch::Tensor logits(const torch::Tensor& hidden_states,
261-
const torch::Tensor& seleted_idxes) {
262-
return npu_lm_head_(hidden_states, seleted_idxes, 0);
263-
}
264-
265-
void load_model(std::unique_ptr<ModelLoader> loader) {
266-
for (const auto& state_dict : loader->get_state_dicts()) {
267-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
268-
npu_lm_head_->load_state_dict(
269-
state_dict->get_dict_with_prefix("lm_head."));
270-
}
271-
272-
// verify
273-
model_->verify_loaded_weights("model.");
274-
npu_lm_head_->verify_loaded_weights("lm_head.");
275-
276-
model_->merge_loaded_weights();
277-
npu_lm_head_->merge_loaded_weights();
278-
}
279-
280-
void prepare_expert_weight(int32_t layer_id,
281-
const std::vector<int32_t>& expert_ids) {
282-
return;
283-
}
284-
void update_expert_weight(int32_t layer_id) { return; }
285-
286-
layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; }
287-
288-
void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; }
289-
290-
layer::NpuWordEmbedding get_npu_word_embedding() {
291-
return model_->get_npu_word_embedding();
292-
}
293-
294-
void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) {
295-
model_->set_npu_word_embedding(npu_word_embedding);
296-
}
297-
298-
private:
299-
// parameter members, must be registered
300-
LlamaModel model_{nullptr};
301-
int device_id_ = 0;
302-
layer::NpuLmHead npu_lm_head_{nullptr};
240+
LlamaForCausalLMImpl(const ModelContext& context)
241+
: LlmForCausalLMImplBase<LlamaModel>(context) {}
303242
};
304243
TORCH_MODULE(LlamaForCausalLM);
305244

xllm/models/llm/npu/llm_model_base.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,11 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
341341

342342
// verify
343343
model_->verify_loaded_weights(prefix);
344-
npu_lm_head_->verify_loaded_weights("lm_head.");
344+
if (tie_word_embeddings) {
345+
npu_lm_head_->verify_loaded_weights(prefix + "embed_tokens.");
346+
} else {
347+
npu_lm_head_->verify_loaded_weights("lm_head.");
348+
}
345349

346350
model_->merge_loaded_weights();
347351
// test

0 commit comments

Comments
 (0)