
Commit 42af0b4

[V1 Loader] Support DeepSeekV3(bf16) (#3294)
* Support new loader for DeepSeekV3(bf16)
* update paddle version
* remove useless attr
1 parent e0aeac5 commit 42af0b4

5 files changed: +141 -5 lines

fastdeploy/model_executor/layers/linear.py

Lines changed: 58 additions & 0 deletions

```diff
@@ -720,6 +720,7 @@ def __init__(
         self.v_head_dim = v_head_dim
         # Split num_attention_heads when using TP inference.
         self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
+        self.local_rank = fd_config.parallel_config.tensor_parallel_rank

         # Initialize parent with combined dimensions
         super().__init__(
@@ -738,6 +739,63 @@ def __init__(
         self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
         self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

+        self.k_b_proj_weight = self.create_parameter(
+            shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank],
+            dtype=self.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        self.v_b_proj_weight = self.create_parameter(
+            shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim],
+            dtype=self.weight_dtype,
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        set_weight_attrs(
+            self.k_b_proj_weight,
+            {"weight_loader": self.weight_loader},
+        )
+
+        if self.nranks > 0:
+            _set_var_distributed(self.k_b_proj_weight, split_axis=1)
+            set_weight_attrs(self.k_b_proj_weight, {"output_dim": True})
+
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        output_dim = getattr(param, "output_dim", None)
+        # Tensor parallelism splits the weight along the output_dim
+        if output_dim is not None:
+            dim = -1
+            size = loaded_weight.get_shape()[dim]
+            block_size = size // self.nranks
+            shard_offset = self.local_rank * block_size
+            shard_size = (self.local_rank + 1) * block_size
+            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+        w = (
+            get_tensor(loaded_weight)
+            .reshape(
+                [
+                    self.kv_lora_rank,
+                    self.num_heads_per_partition,
+                    -1,
+                ]
+            )
+            .transpose(perm=[1, 2, 0])
+        )
+        if param.dtype != w.dtype:
+            w = w.cast(param.dtype)
+        # Split into K and V weights
+        # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
+        wk_b = w[:, : self.qk_nope_head_dim, :]
+        if self.v_head_dim is None:
+            raise ValueError("self.v_head_dim should not be None")
+        # wv_b: [num_heads, kv_lora_rank, v_head_dim]
+        wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])
+
+        self.k_b_proj_weight.set_value(wk_b)
+        self.v_b_proj_weight.set_value(wv_b)
+
     def load_state_dict(self, state_dict: dict):
         """
         Load the combined KV weight and split it into K and V projections
```
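To see what `weight_loader` does to the fused `kv_b_proj` checkpoint tensor, here is a minimal runnable sketch with made-up toy dimensions (the real DeepSeek-V3 sizes are much larger); it mirrors the reshape/transpose/split above but skips the tensor-parallel sharding step:

```python
import paddle

# Toy dimensions, chosen only for illustration.
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 4, 2, 3, 2

# Assumed checkpoint layout: [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]
fused = paddle.randn([kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)])

# Same steps as weight_loader: group columns by head, then move kv_lora_rank last.
w = fused.reshape([kv_lora_rank, num_heads, -1]).transpose(perm=[1, 2, 0])

wk_b = w[:, :qk_nope_head_dim, :]                       # [num_heads, qk_nope_head_dim, kv_lora_rank]
wv_b = w[:, -v_head_dim:, :].transpose(perm=[0, 2, 1])  # [num_heads, kv_lora_rank, v_head_dim]

print(wk_b.shape, wv_b.shape)  # [2, 3, 4] [2, 4, 2]
```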

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -52,7 +52,7 @@ def get_moe_scores(
     compute moe scores using e_score_correction_bias.
     """
     scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
+    scores_with_bias = scores + e_score_correction_bias
     scores, topk_values, topk_idx = noaux_tc(
         scores,
         scores_with_bias,
```
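The dropped `.unsqueeze(0)` matters once the correction bias is stored two-dimensionally. A sketch under that assumption (the `[1, num_experts]` layout comes from the `moe.py` change below):

```python
import paddle

num_tokens, num_experts = 3, 4
scores = paddle.nn.functional.sigmoid(paddle.randn([num_tokens, num_experts]))
bias = paddle.zeros([1, num_experts])  # assumed 2-D layout after the loader change

print((scores + bias).shape)               # [3, 4] -- broadcasts as intended
print((scores + bias.unsqueeze(0)).shape)  # [1, 3, 4] -- stray leading dim
```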

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -508,10 +508,11 @@ def load_state_dict(self, state_dict, is_rearrange: bool = False):
                 gate_correction_bias_tensor = self.extract_gate_correction_bias(
                     self.gate_correction_bias_key, state_dict
                 )
+                if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
+                    gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape)
                 self.gate_correction_bias.set_value(gate_correction_bias_tensor)
             else:
                 self.gate_correction_bias = None
-
         else:
             self.gate_correction_bias = None
```
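A small sketch of that guard in isolation, assuming the checkpoint ships `e_score_correction_bias` flat as `[num_experts]` while the layer allocated `[1, num_experts]`:

```python
import paddle

num_experts = 4
param_shape = [1, num_experts]                        # shape the layer allocated
loaded = paddle.arange(num_experts, dtype="float32")  # checkpoint tensor, shape [4]

# Normalize the checkpoint tensor before set_value, exactly as above.
if loaded.shape != param_shape:
    loaded = loaded.reshape(param_shape)
print(loaded.shape)  # [1, 4]
```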

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 73 additions & 0 deletions

```diff
@@ -628,6 +628,79 @@ def set_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
         self.lm_head.load_state_dict(state_dict)

+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("up_gate_proj", "gate_proj", "gate"),
+            ("up_gate_proj", "up_proj", "up"),
+            ("embed_tokens.embeddings", "embed_tokens", None),
+            ("lm_head.linear", "lm_head", None),
+            ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+        ]
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            param_gate_up_proj_name="experts.up_gate_proj_",
+            param_down_proj_name="experts.down_proj_",
+            num_experts=self.fd_config.model_config.n_routed_experts,
+        )
+        params_dict = dict(self.named_parameters())
+
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
+            loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                if "mlp.experts." in loaded_weight_name and loaded_weight_name not in params_dict:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+
+                if model_param_name not in params_dict:
+                    continue
+
+                param = params_dict[model_param_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in loaded_weight_name:
+                        continue
+                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                    if model_param_name not in params_dict:
+                        continue
+                    param = params_dict[model_param_name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+                    break
+                else:
+                    if loaded_weight_name not in params_dict:
+                        continue
+                    param = params_dict[loaded_weight_name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                    weight_loader(param, loaded_weight)
+                    if "kv_b_proj.weight" in loaded_weight_name:
+                        # handle kv_b_proj_bmm
+                        model_param_name = loaded_weight_name.replace(
+                            "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight"
+                        )
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(param, "weight_loader", None)
+                        weight_loader(param, loaded_weight, shard_id)
+
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
```
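`load_weights` leans on Python's `for`/`else`: the `else` branch runs only when the loop finishes without hitting `break`, which turns the mapping tables into a fall-through dispatcher (stacked params, then expert params, then plain params). A stripped-down sketch with made-up names, not the real mapping tables:

```python
stacked = [("up_gate_proj", "gate_proj"), ("up_gate_proj", "up_proj")]
experts = [("experts.up_gate_proj_0", "experts.0.gate_proj")]

def dispatch(name: str) -> str:
    for param_name, weight_name in stacked:
        if weight_name not in name or "experts." in name:
            continue  # the real loader similarly skips expert weights here
        result = f"stacked -> {name.replace(weight_name, param_name)}"
        break
    else:  # loop ran out without break: no stacked mapping matched
        for param_name, weight_name in experts:
            if weight_name not in name:
                continue
            result = f"expert -> {name.replace(weight_name, param_name)}"
            break
        else:  # no expert mapping matched either: load as a plain parameter
            result = f"default -> {name}"
    return result

print(dispatch("model.decoder_layers.0.mlp.gate_proj.weight"))
# stacked -> model.decoder_layers.0.mlp.up_gate_proj.weight
print(dispatch("model.decoder_layers.3.mlp.experts.0.gate_proj.weight"))
# expert -> model.decoder_layers.3.mlp.experts.up_gate_proj_0.weight
print(dispatch("model.embed_tokens.weight"))
# default -> model.embed_tokens.weight
```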

fastdeploy/model_executor/models/utils.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -78,9 +78,13 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
             if param.dtype != loaded_weight.dtype:
                 loaded_weight = loaded_weight.cast(param.dtype)

-            assert param.shape == loaded_weight.shape, (
-                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-            )
+            if param.shape != loaded_weight.shape:
+                try:
+                    param = param.reshape(loaded_weight.shape)
+                except ValueError as e:
+                    raise ValueError(
+                        f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
+                    )

             param.copy_(loaded_weight, False)
         except Exception:
```
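What the relaxed check in `default_weight_loader` buys, sketched with assumed shapes: a flat checkpoint tensor can now land in a parameter allocated with an extra leading dimension, while a genuinely incompatible shape still fails loudly:

```python
import paddle

param = paddle.zeros([1, 4])
loaded = paddle.ones([4])

if param.shape != loaded.shape:
    # Same element count, so the reshape succeeds; in dygraph the reshaped
    # tensor typically shares storage with the original, which is what lets
    # the subsequent copy_ update the real parameter.
    param = param.reshape(loaded.shape)

param.copy_(loaded, False)
print(param.shape)  # [4]
```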
