/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "core/layers/common/rotary_embedding_util.h"
#include "deepseek_v2.h"
#include "mtp_model_base.h"

// DeepSeek MTP (multi-token prediction) model compatible with huggingface
// weights; built on the DeepSeek v2/v3 decoder layer.
// ref to:
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py

namespace xllm {

class DeepseekMtpModelImpl : public MtpModelImplBase<DeepseekV2DecoderLayer> {
 public:
  DeepseekMtpModelImpl(const ModelContext& context)
      : MtpModelImplBase<DeepseekV2DecoderLayer>("deepseek_v3_mtp", context) {
    auto model_args = context.get_model_args();
    auto options = context.get_tensor_options();

    // The mask value is dtype-dependent: bfloat16 kernels take a
    // boolean-style mask value (1), while other dtypes use a large negative
    // additive bias that stays representable in float16 (-9984).
    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
    attn_mask_ = layer::AttentionMask(options.device(),
                                      options.dtype().toScalarType(),
                                      /*mask_value=*/mask_value);

    // Precompute the cos/sin cache for DeepSeek-style YaRN rotary embeddings.
    // Both the head dim and the rotary dim are qk_rope_head_dim, since only
    // the rope portion of the query/key heads is rotated.
    cos_sin_ = layer::rotary::get_deepseek_rotary_embedding(
        model_args.qk_rope_head_dim(),
        model_args.qk_rope_head_dim(),
        model_args.max_position_embeddings(),
        model_args.rope_scaling_original_max_position_embeddings(),
        model_args.rope_theta(),
        /*interleaved=*/false,
        model_args.rope_scaling_factor(),
        model_args.rope_extrapolation_factor(),
        model_args.rope_scaling_attn_factor(),
        model_args.rope_scaling_beta_fast(),
        model_args.rope_scaling_beta_slow(),
        model_args.rope_scaling_mscale(),
        model_args.rope_scaling_mscale_all_dim(),
        options);
  }
};
TORCH_MODULE(DeepseekMtpModel);

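// The "deepseek_yarn" rope type used here applies an mscale correction when
// the context window is extended. For reference, this sketch mirrors
// yarn_get_mscale from the vllm deepseek_v2.py file cited above; it is
// illustrative only, not xllm's actual rotary_embedding_util implementation,
// and assumes <cmath> is included for std::log.
inline float yarn_get_mscale(float scale = 1.0f, float mscale = 1.0f) {
  // No correction when the context is not extended.
  if (scale <= 1.0f) {
    return 1.0f;
  }
  // The correction grows logarithmically with the scaling factor.
  return 0.1f * mscale * std::log(scale) + 1.0f;
}
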
class DeepseekMtpForCausalLMImpl
    : public MtpForCausalLMImplBase<DeepseekMtpModel> {
 public:
  DeepseekMtpForCausalLMImpl(const ModelContext& context)
      : MtpForCausalLMImplBase<DeepseekMtpModel>(context) {}
};
TORCH_MODULE(DeepseekMtpForCausalLM);

// register the causal model
REGISTER_CAUSAL_MODEL(deepseek_v3_mtp, DeepseekMtpForCausalLM);

// example config:
// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
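// For orientation, a trimmed, illustrative excerpt of the shape of that
// config (values here simply mirror the defaults below; the published
// config.json is authoritative, and elided fields are marked "..."):
//
//   {
//     "model_type": "deepseek_v3_mtp",
//     "vocab_size": 129280,
//     "hidden_size": 7168,
//     "num_attention_heads": 128,
//     "qk_nope_head_dim": 128,
//     "qk_rope_head_dim": 64,
//     "kv_lora_rank": 512,
//     "rope_scaling": { "factor": ..., "beta_fast": ..., ... },
//     ...
//   }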
REGISTER_MODEL_ARGS(deepseek_v3_mtp, [&] {
  LOAD_ARG_OR(model_type, "model_type", "deepseek_v3_mtp");
  LOAD_ARG_OR(dtype, "torch_dtype", "");
  LOAD_ARG_OR(vocab_size, "vocab_size", 129280);
  LOAD_ARG_OR(hidden_size, "hidden_size", 7168);
  LOAD_ARG_OR(n_layers, "num_hidden_layers", 61);
  LOAD_ARG_OR(n_heads, "num_attention_heads", 128);
  LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 128);
  LOAD_ARG_OR(intermediate_size, "intermediate_size", 18432);
  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 163840);
  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
  LOAD_ARG_OR(eos_token_id, "eos_token_id", 1);
  LOAD_ARG_OR(bos_token_id, "bos_token_id", 0);
  LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
  LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
  LOAD_ARG_OR(max_window_layers, "max_window_layers", 61);

  LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 0);
  LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1);
  LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc");
  LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256);
  LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1);
  LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8);
  LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 2048);
  LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 2.5f);
  LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true);
  LOAD_ARG_OR(n_group, "n_group", 8);
  LOAD_ARG_OR(topk_group, "topk_group", 4);
  LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128);
  LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64);
  LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
  LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536);
  LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);

  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
  });
  LOAD_ARG_OR_FUNC(
      rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); });

  SET_ARG(rope_scaling_rope_type, "deepseek_yarn");
  LOAD_ARG(rope_scaling_beta_fast, "rope_scaling.beta_fast");
  LOAD_ARG(rope_scaling_beta_slow, "rope_scaling.beta_slow");
  LOAD_ARG(rope_scaling_factor, "rope_scaling.factor");
  LOAD_ARG_OR(
      rope_extrapolation_factor, "rope_scaling.extrapolation_factor", 1.0f);
  LOAD_ARG(rope_scaling_mscale, "rope_scaling.mscale");
  LOAD_ARG(rope_scaling_mscale_all_dim, "rope_scaling.mscale_all_dim");
  LOAD_ARG(rope_scaling_original_max_position_embeddings,
           "rope_scaling.original_max_position_embeddings");
  LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling.attn_factor", 1.0f);

  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({1}));
});
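
// Keys like "rope_scaling.factor" above address nested JSON objects with a
// dotted path. The sketch below shows one way such a lookup could work; it is
// illustrative only (resolve_dotted_key is a hypothetical name, not how the
// LOAD_ARG macros are actually implemented) and assumes <nlohmann/json.hpp>
// and <string> are included at the top of the file.
inline const nlohmann::json* resolve_dotted_key(const nlohmann::json& root,
                                                const std::string& key) {
  const nlohmann::json* node = &root;
  size_t start = 0;
  while (true) {
    const size_t dot = key.find('.', start);
    const std::string part = key.substr(start, dot - start);
    if (!node->is_object() || !node->contains(part)) {
      // Key absent: a LOAD_ARG_OR-style caller would fall back to its default.
      return nullptr;
    }
    node = &(*node)[part];
    if (dot == std::string::npos) {
      return node;
    }
    start = dot + 1;
  }
}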
}  // namespace xllm