
Commit f9abe9c

update TP and model load
1 parent 20bd1ea

File tree

4 files changed: +12 -8 lines changed

paddlenlp/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -742,7 +742,7 @@ def generate(
             # ['是的', '嗯嗯']
         """
         if generation_config is None:
-            if self.generation_config._from_model_config:
+            if self.generation_config is None or self.generation_config._from_model_config:
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
                     logger.warning(
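
The added `is None` guard matters for models that reach `generate()` without any generation config attached: the old check dereferenced `None`. A minimal standalone sketch of the two checks (`FakeModel` is illustrative, not PaddleNLP API):

    class FakeModel:
        generation_config = None  # e.g. a model loaded without a generation_config.json

    m = FakeModel()
    # Old check: raises AttributeError, since None has no attribute _from_model_config.
    #   if m.generation_config._from_model_config: ...
    # New check: short-circuits on None and rebuilds the config from the model config.
    if m.generation_config is None or m.generation_config._from_model_config:
        print("rebuild generation config from model config")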

paddlenlp/transformers/conversion_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1319,7 +1319,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
             for x in state_keys_real:
                 if x.endswith(key):
                     state_keys_map[key] = x
-                    break
+                    # break  # removed so keys like A.key, B.key ... can all be matched
             if key not in state_keys_map:
                 if not ignore_error:
                     logger.debug(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
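
With the break removed, every key in state_keys_real that ends with the base key is visited, and each hit overwrites the previous one, so the last match wins. A simplified, standalone sketch of the suffix-matching loop:

    state_keys_base = ["weight"]
    state_keys_real = ["A.weight", "B.weight"]

    state_keys_map = {}
    for key in state_keys_base:
        for x in state_keys_real:
            if x.endswith(key):
                state_keys_map[key] = x
                # no break: keep scanning, so a later "B.weight" overrides "A.weight"
    print(state_keys_map)  # {'weight': 'B.weight'}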

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 8 additions & 4 deletions
@@ -856,6 +856,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):

             self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=True)

+            assert self.num_heads % config.tensor_parallel_degree == 0, f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
+            self.num_heads = self.num_heads // config.tensor_parallel_degree
+
         else:
             # for without tensor parallel
             if self.q_lora_rank is None:
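
The assertion enforces that attention heads divide evenly across tensor-parallel ranks before self.num_heads is rescaled to the per-rank count produced by the sharded projections. A toy check with illustrative values:

    num_heads, tensor_parallel_degree = 32, 4
    assert num_heads % tensor_parallel_degree == 0, (
        f"num_heads: {num_heads}, tensor_parallel_degree: {tensor_parallel_degree}"
    )
    num_heads = num_heads // tensor_parallel_degree
    print(num_heads)  # 8 heads live on each rank, so local reshapes must use 8, not 32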
@@ -1228,12 +1231,15 @@ def get_tensor_parallel_split_mappings(num_layers):
             # Column Linear
             base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
             base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
+            base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True)
+
             # if we have enough num_key_value_heads to split, then split it.
             if config.num_key_value_heads % config.tensor_parallel_degree == 0:
                 base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True)

             base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
             base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
@@ -1625,9 +1631,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=False, tensor_parallel_output=tensor_parallel_output
-        )
+        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits

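Dropping the explicit transpose_y=False presumably just leans on the default; in plain paddle.matmul, transpose_y=False is already the default, as a quick check shows (assuming parallel_matmul forwards this flag to matmul):

    import paddle

    x = paddle.randn([2, 4])
    w = paddle.randn([4, 3])
    # transpose_y=False is matmul's default, so omitting it leaves behavior unchanged
    print(paddle.matmul(x, w).shape)                     # [2, 3]
    print(paddle.matmul(x, w, transpose_y=False).shape)  # [2, 3]
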
@@ -1639,7 +1643,7 @@ def __init__(self, config: DeepseekV2Config):
         self.config = config
         self.deepseek_v2 = DeepseekV2Model(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
+        self.lm_head = DeepSeekV2LMHead(config)
         self.criterion = DeepSeekV2PretrainingCriterion(config)

     def get_input_embeddings(self):
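
Swapping the plain nn.Linear for DeepSeekV2LMHead routes the logits projection through parallel_matmul (see the forward hunk above), so the head's weight can stay sharded along the vocabulary axis under tensor parallelism. A schematic of the shape contract on a single rank (plain paddle, no distributed setup, names illustrative):

    import paddle

    hidden_size, vocab_size = 64, 1000
    hidden_states = paddle.randn([2, 8, hidden_size])
    weight = paddle.randn([hidden_size, vocab_size])  # each TP rank would hold a vocab shard of this
    logits = paddle.matmul(hidden_states, weight)     # parallel_matmul reduces to this on one rank
    print(logits.shape)  # [2, 8, 1000]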

paddlenlp/transformers/deepseek_v3/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -24,10 +24,10 @@
 from typing import List, Optional, Tuple, Union

 import paddle
-from paddle import nn

 from ..deepseek_v2.modeling import (
     DeepseekV2ForSequenceClassification,
+    DeepSeekV2LMHead,
     DeepseekV2Model,
     DeepseekV2PretrainedModel,
     DeepSeekV2PretrainingCriterion,
@@ -63,7 +63,7 @@ def __init__(self, config: DeepseekV2Config):
         super().__init__(config)
         self.deepseek_v3 = DeepseekV3Model(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
+        self.lm_head = DeepSeekV2LMHead(config)
         self.criterion = DeepSeekV2PretrainingCriterion(config)

     def get_input_embeddings(self):
