Skip to content

Commit 417b3d1

Browse files
add qwen3 and qwen3_moe (#875)
Co-authored-by: baishihao <baishihao@sensetime.com> Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
1 parent 66a5e8a commit 417b3d1

File tree

17 files changed

+389
-19
lines changed

17 files changed

+389
-19
lines changed

docs/CN/source/models/supported_models.rst

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ lightllm 支持大多数的主流的开源大语言模型以及多模态模型
3232
- :code:`--eos_id 151643 --trust_remote_code`
3333
* - `ChatGLM2-6b <https://github.com/THUDM/ChatGLM2-6B>`_
3434
- :code:`--trust_remote_code`
35-
* - `Baichuan-7b <https://github.com/baichuan-inc/Baichuan-7B>`_
36-
- :code:`--trust_remote_code`
37-
* - `Baichuan-13b <https://github.com/baichuan-inc/Baichuan-13B>`_
38-
- :code:`--trust_remote_code`
39-
* - `Baichuan2-7b <https://github.com/baichuan-inc/Baichuan2>`_
40-
- :code:`--trust_remote_code`
41-
* - `Baichuan2-13b <https://github.com/baichuan-inc/Baichuan2>`_
42-
- :code:`--trust_remote_code`
4335
* - `InternLM-7b <https://github.com/InternLM/InternLM>`_
4436
- :code:`--trust_remote_code`
4537
* - `Yi-34b <https://huggingface.co/01-ai/Yi-34B>`_
@@ -58,6 +50,12 @@ lightllm 支持大多数的主流的开源大语言模型以及多模态模型
5850
- :code:`--data_type bfloat16`
5951
* - `DeepSeek-V2 <https://huggingface.co/deepseek-ai/DeepSeek-V2>`_
6052
- :code:`--data_type bfloat16`
53+
* - `DeepSeek-V3 <https://huggingface.co/deepseek-ai/DeepSeek-V3>`_
54+
-
55+
* - `Qwen3 <https://github.com/QwenLM/Qwen3>`_
56+
-
57+
* - `Qwen3-Moe <https://github.com/QwenLM/Qwen3>`_
58+
-
6159

6260

6361
多模态模型

docs/EN/source/models/supported_models.rst

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ LLM
3131
- :code:`--eos_id 151643 --trust_remote_code`
3232
* - `ChatGLM2-6b <https://github.com/THUDM/ChatGLM2-6B>`_
3333
- :code:`--trust_remote_code`
34-
* - `Baichuan-7b <https://github.com/baichuan-inc/Baichuan-7B>`_
35-
- :code:`--trust_remote_code`
36-
* - `Baichuan-13b <https://github.com/baichuan-inc/Baichuan-13B>`_
37-
- :code:`--trust_remote_code`
38-
* - `Baichuan2-7b <https://github.com/baichuan-inc/Baichuan2>`_
39-
- :code:`--trust_remote_code`
40-
* - `Baichuan2-13b <https://github.com/baichuan-inc/Baichuan2>`_
41-
- :code:`--trust_remote_code`
4234
* - `InternLM-7b <https://github.com/InternLM/InternLM>`_
4335
- :code:`--trust_remote_code`
4436
* - `Yi-34b <https://huggingface.co/01-ai/Yi-34B>`_
@@ -57,6 +49,11 @@ LLM
5749
- :code:`--data_type bfloat16`
5850
* - `DeepSeek-V2 <https://huggingface.co/deepseek-ai/DeepSeek-V2>`_
5951
- :code:`--data_type bfloat16`
52+
* - `Qwen3 <https://github.com/QwenLM/Qwen3>`_
53+
-
54+
* - `Qwen3-Moe <https://github.com/QwenLM/Qwen3>`_
55+
-
56+
6057

6158

6259
VLM

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
self.e_score_correction_bias = None
4646
self.w2_list = [None] * self.n_routed_experts
4747
self.w2_scale_list = [None] * self.n_routed_experts
48-
self.scoring_func = network_config["scoring_func"]
48+
self.scoring_func = network_config.get("scoring_func", "softmax")
4949
self.w1 = [None, None] # weight, weight_scale
5050
self.w2 = [None, None] # weight, weight_scale
5151
self.lock = threading.Lock()

lightllm/models/llama/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ def _verify_params(self):
8181
return
8282

8383
def _init_mem_manager(self):
    """Allocate the KV-cache memory manager for this model.

    Some configs (e.g. Qwen3-style) declare an explicit "head_dim" that can
    differ from hidden_size // num_attention_heads, so prefer the explicit
    value when present and fall back to the derived one otherwise.
    """
    head_dim = self.config.get(
        "head_dim", self.config["hidden_size"] // self.config["num_attention_heads"]
    )
    self.mem_manager = select_mem_manager_class(self.mode)(
        self.max_total_token_num,
        dtype=self.data_type,
        head_num=self.config["num_key_value_heads"] // self.tp_world_size_,
        head_dim=head_dim,
        layer_num=self.config["num_hidden_layers"],
        mem_fraction=self.mem_fraction,
    )

lightllm/models/qwen3/__init__.py

Whitespace-only changes.

lightllm/models/qwen3/layer_infer/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
import torch
import torch.nn.functional as F  # fix: was "torch.functional", which is not the nn functional API
import torch.distributed as dist
import numpy as np
import triton
from typing import Tuple
from functools import partial

from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
18+
19+
20+
class Qwen3TransformerLayerInfer(LlamaTransformerLayerInfer):
    """Transformer-layer inference for Qwen3.

    Extends the llama layer inference with Qwen3's per-head q/k RMSNorm:
    q and k are normalized with dedicated norm weights before rotary
    position embedding is applied.
    """

    def __init__(self, layer_num, network_config, mode=None):
        # Fix: upstream used a mutable default argument (mode=[]); use a None
        # sentinel so one list is not shared across every instance.
        super().__init__(layer_num, network_config, [] if mode is None else mode)
        # Qwen3 configs declare head_dim explicitly; it may differ from
        # hidden_size // num_attention_heads, so take it from the config.
        self.head_dim_ = network_config["head_dim"]
        return

    def _get_qkv(
        self,
        input: torch.Tensor,
        cache_kv,
        infer_state: LlamaInferStateInfo,
        layer_weight: Qwen3TransformerLayerWeight,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project the hidden states to q and packed kv, then apply q/k RMSNorm and RoPE.

        Returns the pair (q, cache_kv).  The annotation was corrected from
        ``torch.Tensor``: the method returns a 2-tuple.
        """
        input = input.view(-1, self.embed_dim_)
        q = layer_weight.q_proj.mm(input)
        # kv projection writes directly into the provided cache buffer.
        cache_kv = layer_weight.kv_proj.mm(
            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
        # Qwen3 applies RMSNorm per attention head to q and k (v is untouched).
        rmsnorm_forward(
            q.reshape(-1, self.head_dim_),
            weight=layer_weight.q_norm_weight_.weight,
            eps=self.eps_,
            out=q.reshape(-1, self.head_dim_),
        )
        # NOTE(review): cache_kv[:, :k, :] is generally non-contiguous, so
        # reshape() here may produce a copy — verify the normalized k values
        # actually land in cache_kv (TODO confirm against the triton kernel).
        rmsnorm_forward(
            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
            weight=layer_weight.k_norm_weight_.weight,
            eps=self.eps_,
            out=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, self.head_dim_),
        )
        # Rotary embedding is applied after the norms, to q and the k part of kv.
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

lightllm/models/qwen3/layer_weights/__init__.py

Whitespace-only changes.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import torch
3+
import math
4+
import numpy as np
5+
from lightllm.common.basemodel import TransformerLayerWeight
6+
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
7+
from lightllm.utils.envs_utils import enable_env_vars
8+
from lightllm.common.basemodel.layer_weights.meta_weights import (
9+
ROWMMWeight,
10+
MultiROWMMWeight,
11+
COLMMWeight,
12+
NormWeight,
13+
FusedMoeWeightTP,
14+
FusedMoeWeightEP,
15+
ROWBMMWeight,
16+
)
17+
from functools import partial
18+
19+
20+
class Qwen3TransformerLayerWeight(LlamaTransformerLayerWeight):
    """Per-layer weights for Qwen3: llama weights plus per-head q/k RMSNorm weights.

    Also derives MoE routing metadata (``n_routed_experts``, ``is_moe``) so the
    MoE variant can share this class.
    """

    def __init__(self, layer_num, data_type, network_config, mode=None, quant_cfg=None):
        # Fix: dense Qwen3 configs do not carry the MoE keys, so indexing
        # network_config["num_experts"] would raise KeyError — default the MoE
        # fields instead (num_experts=0 marks every layer as a plain MLP layer).
        self.n_routed_experts = network_config.get("num_experts", 0)
        self.is_moe = (
            self.n_routed_experts > 0
            and layer_num not in network_config.get("mlp_only_layers", [])
            and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0
        )
        # Fix: upstream used a mutable default argument (mode=[]); use a None
        # sentinel so one list is not shared across every instance.
        super().__init__(layer_num, data_type, network_config, [] if mode is None else mode, quant_cfg)
        return

    def _init_weight_names(self):
        """Extend the llama weight names with Qwen3's q/k norm weight names."""
        super()._init_weight_names()
        self._q_norm_name = f"model.layers.{self.layer_num_}.self_attn.q_norm.weight"
        self._k_norm_name = f"model.layers.{self.layer_num_}.self_attn.k_norm.weight"

    def _init_norm(self):
        """Create the llama norm weights plus the per-head q/k RMSNorm weights."""
        super()._init_norm()
        self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_)
        self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_)

lightllm/models/qwen3/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
from typing import final
3+
from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
4+
from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
5+
from lightllm.models.llama.model import LlamaTpPartModel
6+
from lightllm.utils.log_utils import init_logger
7+
8+
9+
logger = init_logger(__name__)
10+
11+
12+
class Qwen3TpPartModel(LlamaTpPartModel):
    """Tensor-parallel Qwen3 model.

    Reuses the llama model skeleton, swapping in the Qwen3 layer-weight and
    layer-inference classes (which add the per-head q/k RMSNorm).
    """

    # weight class
    transformer_weight_class = Qwen3TransformerLayerWeight

    # infer class
    transformer_layer_infer_class = Qwen3TransformerLayerInfer

    def __init__(self, kvargs):
        # No Qwen3-specific initialization beyond the llama base.
        super().__init__(kvargs)

0 commit comments

Comments
 (0)