Skip to content

Commit 7f90e16

Browse files
committed
refactor: refactor lm_head and embedding layer.
1 parent a1b5802 commit 7f90e16

32 files changed

+273
-795
lines changed

xllm/core/framework/model/causal_lm.h

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,8 @@ limitations under the License.
2626
#include "core/framework/parallel_state.h"
2727
#include "core/framework/quant_args.h"
2828
#include "core/framework/state_dict/state_dict.h"
29-
#if defined(USE_NPU)
30-
#include "layers/npu/llm_head.h"
31-
#include "layers/npu/word_embedding.h"
32-
#endif
29+
#include "layers/lm_head.h"
30+
#include "layers/word_embedding.h"
3331
#include "model_args.h"
3432
#include "model_input_params.h"
3533

@@ -65,10 +63,10 @@ class CausalLM : public torch::nn::Module {
6563
virtual const torch::TensorOptions& options() const = 0;
6664

6765
#if defined(USE_NPU)
68-
virtual hf::LlmHead get_lm_head() = 0;
69-
virtual void set_lm_head(hf::LlmHead& head) = 0;
70-
virtual hf::AtbWordEmbedding get_word_embedding() = 0;
71-
virtual void set_word_embedding(hf::AtbWordEmbedding& embedding) = 0;
66+
virtual LmHead get_lm_head() = 0;
67+
virtual void set_lm_head(LmHead& head) = 0;
68+
virtual WordEmbedding get_word_embedding() = 0;
69+
virtual void set_word_embedding(WordEmbedding& embedding) = 0;
7270
#endif
7371
};
7472

@@ -104,15 +102,15 @@ class CausalLMImpl : public CausalLM {
104102
}
105103

106104
#if defined(USE_NPU)
107-
hf::LlmHead get_lm_head() override { return model_->get_lm_head(); };
105+
LmHead get_lm_head() override { return model_->get_lm_head(); };
108106

109-
void set_lm_head(hf::LlmHead& head) override { model_->set_lm_head(head); };
107+
void set_lm_head(LmHead& head) override { model_->set_lm_head(head); };
110108

111-
hf::AtbWordEmbedding get_word_embedding() override {
109+
WordEmbedding get_word_embedding() override {
112110
return model_->get_word_embedding();
113111
};
114112

115-
void set_word_embedding(hf::AtbWordEmbedding& embedding) override {
113+
void set_word_embedding(WordEmbedding& embedding) override {
116114
model_->set_word_embedding(embedding);
117115
};
118116
#endif

xllm/core/framework/model/causal_vlm.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -65,15 +65,15 @@ class CausalVLMImpl : public CausalVLM {
6565
virtual void update_expert_weight(int32_t layer_id) { return; }
6666

6767
#if defined(USE_NPU)
68-
hf::LlmHead get_lm_head() override { return model_->get_lm_head(); };
68+
LmHead get_lm_head() override { return model_->get_lm_head(); };
6969

70-
void set_lm_head(hf::LlmHead& head) override { model_->set_lm_head(head); };
70+
void set_lm_head(LmHead& head) override { model_->set_lm_head(head); };
7171

72-
hf::AtbWordEmbedding get_word_embedding() override {
72+
WordEmbedding get_word_embedding() override {
7373
return model_->get_word_embedding();
7474
};
7575

76-
void set_word_embedding(hf::AtbWordEmbedding& embedding) override {
76+
void set_word_embedding(WordEmbedding& embedding) override {
7777
model_->set_word_embedding(embedding);
7878
};
7979
#endif

xllm/core/layers/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,9 @@ cc_library(
5252
qwen3_moe_decoder_layer.h
5353
rms_norm.h
5454
siglip_encoder_layer.h
55+
pos_embedding.h
56+
word_embedding.h
57+
lm_head.h
5558
SRCS
5659
multi_head_attention.cpp
5760
DEPS

xllm/core/layers/npu/llm_head.cpp renamed to xllm/core/layers/lm_head.h

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13,17 +13,22 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "llm_head.h"
16+
#pragma once
1717

18-
#include "atb_head_impl.h"
18+
#if defined(USE_NPU)
19+
#include "npu/npu_lm_head_impl.h"
20+
#include "pytorch/adapter/utils/utils.h"
21+
#endif
1922

20-
namespace xllm::hf {
23+
namespace xllm {
2124

22-
std::shared_ptr<LlmHeadImpl> create_llm_head_layer(const Context& context) {
23-
return std::make_shared<AtbLmHeadImpl>(context);
24-
}
25+
class LmHead : public torch::nn::ModuleHolder<NpuLmHeadImpl> {
26+
public:
27+
using torch::nn::ModuleHolder<NpuLmHeadImpl>::ModuleHolder;
28+
using Impl __attribute__((__unused__)) = NpuLmHeadImpl;
2529

26-
LlmHead::LlmHead(const Context& context)
27-
: ModuleHolder(create_llm_head_layer(context)) {}
30+
LmHead(const Context& context)
31+
: ModuleHolder(std::make_shared<NpuLmHeadImpl>(context)) {}
32+
};
2833

29-
} // namespace xllm::hf
34+
} // namespace xllm

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,9 @@ cc_library(
66
NAME
77
npu_layers
88
HDRS
9-
atb_base.h
10-
word_embedding.h
11-
atb_word_embedding_impl.h
12-
pos_embedding.h
13-
llm_head.h
14-
atb_head_impl.h
9+
npu_word_embedding_impl.h
10+
npu_pos_embedding_impl.h
11+
npu_lm_head_impl.h
1512
$<$<BOOL:${USE_A2}>:npu_qwen2dot5_vision_encoder_layer_impl.h>
1613
$<$<BOOL:${USE_A2}>:npu_qwen3_moe_decoder_layer_impl.h>
1714
# atb_parallel_linear.h
@@ -26,14 +23,11 @@ cc_library(
2623
npu_rms_norm_impl.h
2724
npu_siglip_encoder_layer_impl.h
2825
SRCS
26+
npu_word_embedding_impl.cpp
27+
npu_pos_embedding_impl.cpp
28+
npu_lm_head_impl.cpp
2929
$<$<BOOL:${USE_A2}>:npu_qwen2dot5_vision_encoder_layer_impl.cpp>
3030
$<$<BOOL:${USE_A2}>:npu_qwen3_moe_decoder_layer_impl.cpp>
31-
atb_base.cpp
32-
word_embedding.cpp
33-
atb_word_embedding_impl.cpp
34-
pos_embedding.cpp
35-
llm_head.cpp
36-
atb_head_impl.cpp
3731
# atb_parallel_linear.cpp
3832
buffer/atb_buffer.cpp
3933
buffer/atb_workspace.cpp

xllm/core/layers/npu/atb_base.cpp

Lines changed: 0 additions & 250 deletions
This file was deleted.

0 commit comments

Comments
 (0)