 class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
-        self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         return
 
     def load_hf_weights(self, weights):
@@ -16,18 +15,12 @@ def load_hf_weights(self, weights):
         split_start = split_indexes[self.tp_rank_]
         split_end = split_indexes[self.tp_rank_ + 1]
         if "model.embed_tokens.weight" in weights:
-            if self.enable_dp:
-                self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"])
-            else:
-                self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
+            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
         tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
         if tie_word_embeddings:
             self.lm_head_weight_ = self.wte_weight_
         if "lm_head.weight" in weights:
-            if self.enable_dp:
-                self.lm_head_weight_ = self._cuda(weights["lm_head.weight"])
-            else:
-                self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
+            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
         if "model.norm.weight" in weights:
             self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
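For context on the retained slicing path: the embedding and lm_head rows are taken from split_start to split_end, which assumes split_indexes partitions the vocabulary dimension across tensor-parallel ranks. The sketch below is a minimal illustration of one way such a partition can be computed and applied; the helper name vocab_split_indexes and the sizes are assumptions for the example, not code from this repository.

import numpy as np
import torch

def vocab_split_indexes(vocab_size: int, world_size: int) -> np.ndarray:
    # Hypothetical helper: evenly partition vocab rows across tp ranks,
    # returning world_size + 1 boundaries, e.g. [0, 16000, 32000] for 2 ranks.
    return np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)

# Example: slice the embedding table for one tensor-parallel rank.
vocab_size, hidden = 32000, 4096
full_embed = torch.randn(vocab_size, hidden)  # stand-in for weights["model.embed_tokens.weight"]
tp_rank, world_size = 0, 2
idx = vocab_split_indexes(vocab_size, world_size)
shard = full_embed[idx[tp_rank]:idx[tp_rank + 1], :]  # rows owned by this rank
print(shard.shape)  # torch.Size([16000, 4096])

Each rank then keeps only its own row shard on GPU, which is what the unconditional split_start:split_end slices in the diff above implement once the enable_dp branch is removed.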