diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h index efa85c15e..ce276242e 100644 --- a/xllm/core/framework/model/causal_lm.h +++ b/xllm/core/framework/model/causal_lm.h @@ -20,9 +20,6 @@ limitations under the License. #include "graph/types.h" #include "layers/npu/npu_lm_head_impl.h" #include "layers/npu/npu_word_embedding_impl.h" -#else -#include "layers/lm_head.h" -#include "layers/word_embedding.h" #endif // clang-format on #include @@ -34,12 +31,13 @@ limitations under the License. #include "core/framework/model_loader.h" #include "core/framework/quant_args.h" #include "core/framework/state_dict/state_dict.h" +#include "layers/lm_head.h" +#include "layers/word_embedding.h" #include "model_args.h" #include "model_input_params.h" namespace xllm { -#if !defined(USE_NPU) namespace detail { template struct has_get_lm_head : std::false_type {}; @@ -76,7 +74,6 @@ struct has_set_word_embedding< std::void_t()->set_word_embedding( std::declval()))>> : std::true_type {}; } // namespace detail -#endif class CausalLM : public torch::nn::Module { public: @@ -113,7 +110,7 @@ class CausalLM : public torch::nn::Module { virtual void set_npu_lm_head(layer::NpuLmHead& head) = 0; virtual layer::NpuWordEmbedding get_npu_word_embedding() = 0; virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) = 0; -#else +#endif virtual layer::LmHead get_lm_head() { LOG(FATAL) << "Method 'get_lm_head' is not implemented/supported by this model."; @@ -130,7 +127,6 @@ class CausalLM : public torch::nn::Module { LOG(FATAL) << "Method 'set_word_embedding' is not implemented/supported by " "this model."; } -#endif }; template @@ -180,7 +176,7 @@ class CausalLMImpl : public CausalLM { void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override { model_->set_npu_word_embedding(embedding); } -#else +#endif layer::LmHead get_lm_head() override { if constexpr (detail::has_get_lm_head::value) { return model_->get_lm_head(); @@ -212,7 +208,6 @@ class CausalLMImpl : public CausalLM { CausalLM::set_word_embedding(embedding); } } -#endif torch::Device device() const override { return options_.device(); } diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index f00e30aa7..497b66883 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -79,7 +79,7 @@ class CausalVLMImpl : public CausalVLM { void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override { model_->set_npu_word_embedding(embedding); } -#else +#endif layer::LmHead get_lm_head() override { if constexpr (detail::has_get_lm_head::value) { return model_->get_lm_head(); @@ -111,7 +111,6 @@ class CausalVLMImpl : public CausalVLM { CausalLM::set_word_embedding(embedding); } } -#endif torch::Device device() const override { return options_.device(); } diff --git a/xllm/core/framework/model/mm_embedding_vlm.h b/xllm/core/framework/model/mm_embedding_vlm.h index aa5256f6d..bd8774314 100644 --- a/xllm/core/framework/model/mm_embedding_vlm.h +++ b/xllm/core/framework/model/mm_embedding_vlm.h @@ -72,12 +72,11 @@ class MMEmbeddingVLMImpl : public MMEmbeddingVLM { virtual void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) { return; } -#else +#endif virtual void set_lm_head(layer::LmHead& head) { return; } virtual layer::LmHead get_lm_head() { return nullptr; } virtual layer::WordEmbedding get_word_embedding() { return nullptr; } virtual void set_word_embedding(layer::WordEmbedding& embedding) { return; } -#endif void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); diff --git a/xllm/core/layers/config.h b/xllm/core/layers/config.h index 51b6dabbc..46504de0f 100644 --- a/xllm/core/layers/config.h +++ b/xllm/core/layers/config.h @@ -39,103 +39,26 @@ limitations under the License. } \ } -#if defined(USE_NPU) -#include "npu/npu_word_embedding_impl.h" -#else -#include "common/word_embedding_impl.h" -#endif - -#if defined(USE_NPU) -#include "npu/npu_pos_embedding_impl.h" -#else -#include "common/rotary_embedding.h" -#endif - -#if defined(USE_NPU) -#include "npu/npu_lm_head_impl.h" -#else #include "common/linear.h" -UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl) -#endif +#include "common/qwen2_5_vision_layer.h" +#include "common/qwen2_decoder_layer.h" +#include "common/qwen3_moe_decoder_layer.h" +#include "common/rotary_embedding.h" +#include "common/word_embedding_impl.h" -#if defined(USE_NPU) -#include "npu/npu_deepseek_v2_decoder_layer_impl.h" -#elif defined(USE_MLU) +#if defined(USE_MLU) #include "mlu/deepseek_v2_decoder_layer_impl.h" #else REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV2DecoderLayerImpl); #endif -#if defined(USE_NPU) -#include "npu/npu_deepseek_v32_decoder_layer_impl.h" -#else -REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl); -#endif - -#if defined(USE_NPU) -#include "npu/npu_llama_decoder_layer_impl.h" -#else -REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl); -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen2_decoder_layer_impl.h" -#else -#include "common/qwen2_decoder_layer.h" -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen2_vision_encoder_layer_impl.h" -#else -#include "common/qwen2_5_vision_layer.h" +UNIFY_CLASS_NAME(ColumnParallelLinearImpl, LmHeadImpl) UNIFY_CLASS_NAME(Qwen2_VisionLayerImpl, Qwen2VisionEncoderLayerImpl) -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen2dot5_vision_encoder_layer_impl.h" -#else -#include "common/qwen2_5_vision_layer.h" UNIFY_CLASS_NAME(Qwen2_5_VisionLayerImpl, Qwen2dot5VisionEncoderLayerImpl) -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen3_decoder_layer_impl.h" -#else -#include "common/qwen2_decoder_layer.h" -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen3_moe_decoder_layer_impl.h" -#else -#include "common/qwen3_moe_decoder_layer.h" -#endif - -#if defined(USE_NPU) -#include "npu/npu_qwen3_vision_encoder_layer_impl.h" -#else -#include "common/qwen2_5_vision_layer.h" UNIFY_CLASS_NAME(Qwen3_VisionLayerImpl, Qwen3VisionEncoderLayerImpl) -#endif -#if defined(USE_NPU) -#include "npu/npu_siglip_encoder_layer_impl.h" -#else +REGISTER_NOT_IMPLEMENTED_CLASS(DeepseekV32DecoderLayerImpl); +REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl); REGISTER_NOT_IMPLEMENTED_CLASS(SiglipEncoderLayerImpl); -#endif - -#if defined(USE_NPU) -#include "npu/npu_glm4_decoder_layer_impl.h" -#else REGISTER_NOT_IMPLEMENTED_CLASS(Glm4DecoderLayerImpl); -#endif - -#if defined(USE_NPU) -#include "npu/npu_glm4_vision_encoder_layer_impl.h" -namespace xllm { -namespace layer { -using Glm4VisionEncoderLayerImpl = NpuGlm4VisionEncoderLayerImpl; -} -} // namespace xllm -#else -REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl); -#endif +REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl); \ No newline at end of file diff --git a/xllm/core/runtime/llm_worker_impl.h b/xllm/core/runtime/llm_worker_impl.h index cfc64cdb3..73e58ee19 100644 --- a/xllm/core/runtime/llm_worker_impl.h +++ b/xllm/core/runtime/llm_worker_impl.h @@ -59,7 +59,7 @@ class LLMWorkerImpl : public WorkerImpl { model_->set_npu_word_embedding(embedding); }; -#else +#endif layer::LmHead get_lm_head() { return model_->get_lm_head(); }; void set_lm_head(layer::LmHead& head) { model_->set_lm_head(head); }; @@ -72,8 +72,6 @@ class LLMWorkerImpl : public WorkerImpl { model_->set_word_embedding(embedding); }; -#endif - private: std::unique_ptr beam_searcher_; }; diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp index 1e2aeeed3..fa3c598bd 100644 --- a/xllm/core/runtime/speculative_worker_impl.cpp +++ b/xllm/core/runtime/speculative_worker_impl.cpp @@ -195,10 +195,10 @@ bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path, auto word_embedding = impl_->get_npu_word_embedding(); draft_impl_->set_npu_word_embedding(word_embedding); } else { - // TODO: Support TORCH backend via torch_npu encapsulation in the future. - // Currently, it is explicitly disabled. - LOG(FATAL) - << "SpeculativeWorkerImpl::init_model not support TORCH backend"; + auto head = impl_->get_lm_head(); + draft_impl_->set_lm_head(head); + auto word_embedding = impl_->get_word_embedding(); + draft_impl_->set_word_embedding(word_embedding); } #else auto head = impl_->get_lm_head();