xllm/core/framework/model/causal_lm.h (4 additions, 9 deletions)
@@ -20,9 +20,6 @@ limitations under the License.
#include "graph/types.h"
#include "layers/npu/npu_lm_head_impl.h"
#include "layers/npu/npu_word_embedding_impl.h"
#else
#include "layers/lm_head.h"
#include "layers/word_embedding.h"
#endif
// clang-format on
#include <c10/core/Device.h>
@@ -34,12 +31,13 @@ limitations under the License.
 #include "core/framework/model_loader.h"
 #include "core/framework/quant_args.h"
 #include "core/framework/state_dict/state_dict.h"
+#include "layers/lm_head.h"
+#include "layers/word_embedding.h"
 #include "model_args.h"
 #include "model_input_params.h"
 
 namespace xllm {
 
-#if !defined(USE_NPU)
 namespace detail {
 template <typename T, typename = void>
 struct has_get_lm_head : std::false_type {};
@@ -76,7 +74,6 @@ struct has_set_word_embedding<
         std::void_t<decltype(std::declval<T>()->set_word_embedding(
             std::declval<layer::WordEmbedding&>()))>> : std::true_type {};
 }  // namespace detail
-#endif
 
 class CausalLM : public torch::nn::Module {
  public:
@@ -113,7 +110,7 @@ class CausalLM : public torch::nn::Module {
   virtual void set_npu_lm_head(layer::NpuLmHead& head) = 0;
   virtual layer::NpuWordEmbedding get_npu_word_embedding() = 0;
   virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) = 0;
-#else
+#endif
   virtual layer::LmHead get_lm_head() {
     LOG(FATAL)
         << "Method 'get_lm_head' is not implemented/supported by this model.";
@@ -130,7 +127,6 @@ class CausalLM : public torch::nn::Module {
     LOG(FATAL) << "Method 'set_word_embedding' is not implemented/supported by "
                   "this model.";
   }
-#endif
 };
 
 template <typename Model>
@@ -180,7 +176,7 @@ class CausalLMImpl : public CausalLM {
   void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override {
     model_->set_npu_word_embedding(embedding);
   }
-#else
+#endif
   layer::LmHead get_lm_head() override {
     if constexpr (detail::has_get_lm_head<Model>::value) {
       return model_->get_lm_head();
@@ -212,7 +208,6 @@ class CausalLMImpl : public CausalLM {
       CausalLM::set_word_embedding(embedding);
     }
   }
-#endif
 
   torch::Device device() const override { return options_.device(); }
 
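Note on the traits this file now compiles unconditionally: `detail::has_get_lm_head` and friends are instances of the C++17 `std::void_t` detection idiom, and `CausalLMImpl` uses them with `if constexpr` to call through to the wrapped model only when it actually provides the accessor, falling back to the `LOG(FATAL)` defaults in `CausalLM` otherwise. A minimal self-contained sketch of the same pattern (the model types and the `int` return value here are illustrative, not from the codebase):

```cpp
#include <iostream>
#include <memory>
#include <type_traits>

// The std::void_t detection idiom, shaped like detail::has_get_lm_head above.
// T is a pointer-like model handle; detection asks whether T->get_lm_head()
// is well-formed.
template <typename T, typename = void>
struct has_get_lm_head : std::false_type {};

template <typename T>
struct has_get_lm_head<
    T,
    std::void_t<decltype(std::declval<T>()->get_lm_head())>>
    : std::true_type {};

// Illustrative model types (not from the codebase).
struct WithHead {
  int get_lm_head() { return 42; }
};
struct WithoutHead {};

// Mirrors CausalLMImpl::get_lm_head: call through when the wrapped model
// provides the accessor, otherwise fall back (the real code falls back to a
// base-class method that LOG(FATAL)s).
template <typename ModelPtr>
int get_lm_head_or_default(ModelPtr& model) {
  if constexpr (has_get_lm_head<ModelPtr>::value) {
    return model->get_lm_head();
  } else {
    return -1;  // stand-in for the unsupported-method fallback
  }
}

int main() {
  auto with = std::make_unique<WithHead>();
  auto without = std::make_unique<WithoutHead>();
  std::cout << get_lm_head_or_default(with) << "\n";     // prints 42
  std::cout << get_lm_head_or_default(without) << "\n";  // prints -1
}
```

The practical effect of dropping the `USE_NPU` guard is that every backend sees the same virtual interface; models lacking an accessor now fail at runtime with a clear message rather than failing to compile.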
xllm/core/framework/model/causal_vlm.h (1 addition, 2 deletions)
@@ -79,7 +79,7 @@ class CausalVLMImpl : public CausalVLM {
   void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override {
     model_->set_npu_word_embedding(embedding);
   }
-#else
+#endif
   layer::LmHead get_lm_head() override {
     if constexpr (detail::has_get_lm_head<Model>::value) {
       return model_->get_lm_head();
@@ -111,7 +111,6 @@ class CausalVLMImpl : public CausalVLM {
       CausalLM::set_word_embedding(embedding);
     }
   }
-#endif
 
   torch::Device device() const override { return options_.device(); }
 
xllm/core/framework/model/mm_embedding_vlm.h (1 addition, 2 deletions)
@@ -72,12 +72,11 @@ class MMEmbeddingVLMImpl : public MMEmbeddingVLM {
   virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) {
     return;
   }
-#else
+#endif
   virtual void set_lm_head(layer::LmHead& head) { return; }
   virtual layer::LmHead get_lm_head() { return nullptr; }
   virtual layer::WordEmbedding get_word_embedding() { return nullptr; }
   virtual void set_word_embedding(layer::WordEmbedding& embedding) { return; }
-#endif
 
   void load_model(std::unique_ptr<ModelLoader> loader) override {
     model_->load_model(std::move(loader));
xllm/core/layers/config.h (10 additions, 87 deletions)
@@ -39,103 +39,26 @@ limitations under the License.
   } \
   }
 
-#if defined(USE_NPU)
-#include "npu/npu_word_embedding_impl.h"
-#else
-#include "common/word_embedding_impl.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_pos_embedding_impl.h"
-#else
-#include "common/rotary_embedding.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_lm_head_impl.h"
-#else
-#include "common/linear.h"
-UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
-#endif
+#include "common/qwen2_5_vision_layer.h"
+#include "common/qwen2_decoder_layer.h"
+#include "common/qwen3_moe_decoder_layer.h"
+#include "common/rotary_embedding.h"
+#include "common/word_embedding_impl.h"
 
-#if defined(USE_NPU)
-#include "npu/npu_deepseek_v2_decoder_layer_impl.h"
-#elif defined(USE_MLU)
+#if defined(USE_MLU)
 #include "mlu/deepseek_v2_decoder_layer_impl.h"
 #else
 REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV2DecoderLayerImpl);
 #endif
 
-#if defined(USE_NPU)
-#include "npu/npu_deepseek_v32_decoder_layer_impl.h"
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_llama_decoder_layer_impl.h"
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2_decoder_layer_impl.h"
-#else
-#include "common/qwen2_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
+UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
 UNIFY_CLASS_NAME(Qwen2_VisionLayerImpl, Qwen2VisionEncoderLayerImpl)
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen2dot5_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
 UNIFY_CLASS_NAME(Qwen2_5_VisionLayerImpl, Qwen2dot5VisionEncoderLayerImpl)
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_decoder_layer_impl.h"
-#else
-#include "common/qwen2_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_moe_decoder_layer_impl.h"
-#else
-#include "common/qwen3_moe_decoder_layer.h"
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_qwen3_vision_encoder_layer_impl.h"
-#else
-#include "common/qwen2_5_vision_layer.h"
 UNIFY_CLASS_NAME(Qwen3_VisionLayerImpl, Qwen3VisionEncoderLayerImpl)
-#endif
 
-#if defined(USE_NPU)
-#include "npu/npu_siglip_encoder_layer_impl.h"
-#else
+REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl);
+REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl);
 REGISTER_NOT_IMPLEMENTED_CLASS(SiglipEncoderLayerImpl);
-#endif
 
-#if defined(USE_NPU)
-#include "npu/npu_glm4_decoder_layer_impl.h"
-#else
 REGISTER_NOT_IMPLEMENTED_CLASS(Glm4DecoderLayerImpl);
-#endif
-
-#if defined(USE_NPU)
-#include "npu/npu_glm4_vision_encoder_layer_impl.h"
-namespace xllm {
-namespace layer {
-using Glm4VisionEncoderLayerImpl = NpuGlm4VisionEncoderLayerImpl;
-}
-} // namespace xllm
-#else
-REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
-#endif
+REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
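For readers unfamiliar with the two macros this file leans on: `UNIFY_CLASS_NAME` maps a backend-specific implementation class onto the canonical layer name, and `REGISTER_NOT_IMPLEMENTED_CLASS` stubs out layers a backend does not support. The real definitions live earlier in config.h (the trailing `} \` / `}` at the top of this hunk is the tail of one of them); the sketch below is an assumption about their shape, not the actual code:

```cpp
#include <stdexcept>

// Assumed shape of UNIFY_CLASS_NAME: alias the chosen implementation to the
// unified name used by the rest of the framework, e.g.
// UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl)
#define UNIFY_CLASS_NAME(Impl, Unified) \
  namespace xllm {                      \
  namespace layer {                     \
  using Unified = Impl;                 \
  }                                     \
  }

// Assumed shape of REGISTER_NOT_IMPLEMENTED_CLASS: a stub whose constructor
// fails loudly if the layer is instantiated on an unsupported backend. The
// two closing braces match the macro tail visible at the top of the hunk.
#define REGISTER_NOT_IMPLEMENTED_CLASS(Unified)                        \
  namespace xllm {                                                     \
  namespace layer {                                                    \
  class Unified {                                                      \
   public:                                                             \
    template <typename... Args>                                        \
    explicit Unified(Args&&...) {                                      \
      throw std::runtime_error(#Unified                                \
                               " is not implemented on this backend"); \
    }                                                                  \
  };                                                                   \
  }                                                                    \
  }
```

Under that reading, the diff's effect is that the common CPU/GPU layer implementations are always included and aliased, while genuinely backend-specific layers either compile to the MLU implementation or to a not-implemented stub.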
xllm/core/runtime/llm_worker_impl.h (1 addition, 3 deletions)
@@ -59,7 +59,7 @@ class LLMWorkerImpl : public WorkerImpl {
     model_->set_npu_word_embedding(embedding);
   };
 
-#else
+#endif
   layer::LmHead get_lm_head() { return model_->get_lm_head(); };
 
   void set_lm_head(layer::LmHead& head) { model_->set_lm_head(head); };
@@ -72,8 +72,6 @@ class LLMWorkerImpl : public WorkerImpl {
     model_->set_word_embedding(embedding);
   };
 
-#endif
-
  private:
   std::unique_ptr<BeamSearcher> beam_searcher_;
 };
xllm/core/runtime/speculative_worker_impl.cpp (4 additions, 4 deletions)
@@ -195,10 +195,10 @@ bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path,
     auto word_embedding = impl_->get_npu_word_embedding();
     draft_impl_->set_npu_word_embedding(word_embedding);
   } else {
-    // TODO: Support TORCH backend via torch_npu encapsulation in the future.
-    // Currently, it is explicitly disabled.
-    LOG(FATAL)
-        << "SpeculativeWorkerImpl::init_model not support TORCH backend";
+    auto head = impl_->get_lm_head();
+    draft_impl_->set_lm_head(head);
+    auto word_embedding = impl_->get_word_embedding();
+    draft_impl_->set_word_embedding(word_embedding);
   }
 #else
   auto head = impl_->get_lm_head();
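The hunk above replaces the hard failure on the TORCH path with the same sharing the ATB branch already performs: the draft model reuses the target model's output head and token embedding instead of loading its own copies, presumably to avoid duplicating the vocabulary-sized weight matrices and to keep the two models' logit projections identical. A self-contained sketch of the pattern, with stand-in types (the real handles are `layer::LmHead` / `layer::WordEmbedding`, which wrap shared weights reference-style):

```cpp
#include <memory>

// Stand-ins for layer::LmHead / layer::WordEmbedding: cheap handles whose
// copies point at the same underlying weights.
struct LmHead {
  std::shared_ptr<int> weights;
};
struct WordEmbedding {
  std::shared_ptr<int> weights;
};

// Stand-in for the CausalLM-backed worker; the accessor pair mirrors the
// interface that causal_lm.h now exposes on every backend.
struct Model {
  LmHead head;
  WordEmbedding embedding;
  LmHead get_lm_head() { return head; }
  void set_lm_head(LmHead& h) { head = h; }
  WordEmbedding get_word_embedding() { return embedding; }
  void set_word_embedding(WordEmbedding& e) { embedding = e; }
};

int main() {
  Model target{{std::make_shared<int>(1)}, {std::make_shared<int>(2)}};
  Model draft;  // starts with its own (empty) layers

  // Mirrors the new non-ATB branch of init_model: the draft reuses the
  // target's head and embedding instead of LOG(FATAL)-ing as before.
  auto head = target.get_lm_head();
  draft.set_lm_head(head);
  auto word_embedding = target.get_word_embedding();
  draft.set_word_embedding(word_embedding);

  // Both models now point at the same underlying weights.
  return draft.head.weights == target.head.weights ? 0 : 1;
}
```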