45
45
from vllm .attention .selector import (_Backend , backend_name_to_enum ,
46
46
get_global_forced_attn_backend )
47
47
from vllm .config import CacheConfig , MultiModalConfig
48
- from vllm .distributed import parallel_state
48
+ from vllm .distributed import get_pp_group , parallel_state
49
49
from vllm .distributed import utils as dist_utils
50
50
from vllm .inputs import INPUT_REGISTRY , InputContext , LLMInputs
51
51
from vllm .logger import init_logger
68
68
from vllm .sequence import IntermediateTensors , SequenceData
69
69
from vllm .transformers_utils .processor import get_processor
70
70
71
+ from .utils import (PPMissingLayer , is_pp_missing_parameter ,
72
+ make_empty_intermediate_tensors_factory )
73
+
71
74
logger = init_logger (__name__ )
72
75
73
76
# === Vision Inputs === #
@@ -856,15 +859,21 @@ def __init__(self,
856
859
857
860
self .model = Qwen2Model (config , cache_config , quant_config )
858
861
859
- if config .tie_word_embeddings :
860
- self .lm_head = self .model .embed_tokens
862
+ if get_pp_group ().is_last_rank :
863
+ if config .tie_word_embeddings :
864
+ self .lm_head = self .model .embed_tokens
865
+ else :
866
+ self .lm_head = ParallelLMHead (config .vocab_size ,
867
+ config .hidden_size ,
868
+ quant_config = quant_config )
861
869
else :
862
- self .lm_head = ParallelLMHead (config .vocab_size ,
863
- config .hidden_size ,
864
- quant_config = quant_config )
870
+ self .lm_head = PPMissingLayer ()
865
871
866
872
self .logits_processor = LogitsProcessor (config .vocab_size )
867
873
self .sampler = Sampler ()
874
+ self .make_empty_intermediate_tensors = (
875
+ make_empty_intermediate_tensors_factory (
876
+ ["hidden_states" , "residual" ], config .hidden_size ))
868
877
869
878
def _validate_and_reshape_mm_tensor (self ,
870
879
mm_input : Union [torch .Tensor ,
@@ -979,7 +988,8 @@ def forward(
979
988
image_input = self ._parse_and_validate_image_input (** kwargs )
980
989
video_input = self ._parse_and_validate_video_input (** kwargs )
981
990
982
- if image_input is None and video_input is None :
991
+ if (image_input is None
992
+ and video_input is None ) or not get_pp_group ().is_first_rank :
983
993
inputs_embeds = None
984
994
else :
985
995
if getattr (self .config , "rope_scaling" , {}).get ("type" ,
@@ -1015,6 +1025,7 @@ def forward(
1015
1025
positions = positions ,
1016
1026
kv_caches = kv_caches ,
1017
1027
attn_metadata = attn_metadata ,
1028
+ intermediate_tensors = intermediate_tensors ,
1018
1029
inputs_embeds = inputs_embeds ,
1019
1030
)
1020
1031
return hidden_states
@@ -1055,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1055
1066
# Skip loading extra bias for GPTQ models.
1056
1067
if name .endswith (".bias" ) and name not in params_dict :
1057
1068
continue
1069
+ if is_pp_missing_parameter (name , self ):
1070
+ continue
1058
1071
param = params_dict [name ]
1059
1072
weight_loader = param .weight_loader
1060
1073
weight_loader (param , loaded_weight , shard_id )
@@ -1081,6 +1094,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1081
1094
# Skip loading extra bias for GPTQ models.
1082
1095
if name .endswith (".bias" ) and name not in params_dict :
1083
1096
continue
1097
+ if is_pp_missing_parameter (name , self ):
1098
+ continue
1084
1099
param = params_dict [name ]
1085
1100
except KeyError :
1086
1101
print (params_dict .keys ())
0 commit comments