4 files changed, 18 insertions(+), 6 deletions(-)

@@ -306,6 +306,7 @@ class ModelConfig:
 
     # fp32 lm head
     enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False
 
     def get_head_size(self):
        """Get head size."""

@@ -357,7 +358,10 @@ def from_pretrained(
        enforce_fp32_head = hf_overrides.pop('enforce_fp32_head', False)
        override_hf_config(model_config.hf_config, hf_overrides)
 
+        # for fp32 head
        model_config.enforce_fp32_head = enforce_fp32_head
+        model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
+
        # for serialization of transformers modules
        maybe_register_config_serialize_by_value(trust_remote_code)
 
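For context, `tie_word_embeddings` is a standard field on Hugging Face `PretrainedConfig`, which is why a plain `getattr` with a `False` fallback is enough to probe it on arbitrary architectures. A minimal sketch of that probe, assuming `transformers` is installed (the config value here is illustrative):

from transformers import PretrainedConfig

# PretrainedConfig stores tie_word_embeddings (base default is True);
# models that untie their lm_head set it to False in config.json.
hf_config = PretrainedConfig(tie_word_embeddings=False)

# Mirrors the probe in from_pretrained above: a missing attribute maps to False.
tie = getattr(hf_config, 'tie_word_embeddings', False)
print(tie)  # False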
@@ -1040,11 +1040,14 @@ def _build_model(self):
         # for router replay
         enable_return_routed_experts = self.misc_config.enable_return_routed_experts and self.need_output
 
-        build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder,
-                                            dllm_config=self.misc_config.dllm_config,
-                                            strategy_factory=self.strategy_factory,
-                                            enable_return_routed_experts=enable_return_routed_experts,
-                                            enforce_fp32_head=self.model_config.enforce_fp32_head)
+        build_model_ctx = BuildModelContext(
+            disable_vision_encoder=self.misc_config.disable_vision_encoder,
+            dllm_config=self.misc_config.dllm_config,
+            strategy_factory=self.strategy_factory,
+            enable_return_routed_experts=enable_return_routed_experts,
+            enforce_fp32_head=self.model_config.enforce_fp32_head,
+            tie_word_embeddings=self.model_config.tie_word_embeddings,
+        )
         patched_model = build_patched_model(self.model_config,
                                             device=device,
                                             model_format=self.misc_config.model_format,
@@ -390,6 +390,7 @@ class BuildModelContext:
     strategy_factory: 'StrategyFactoryBase' = None
     enable_return_routed_experts: bool = False
     enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False
 
 
 class StepContextManager(CtxMgrBase[StepContext]):
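The flag rides along on `BuildModelContext` so that deeply nested builders such as `build_embedding` can read it without threading a parameter through every call site. A hypothetical sketch of that pattern, assuming a module-level current-context holder (only `BuildModelContext` and `get_build_model_context` are names from this diff; the setter and holder are illustrative):

from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class BuildModelContext:
    enforce_fp32_head: bool = False
    tie_word_embeddings: bool = False

_current_ctx = BuildModelContext()

@contextmanager
def build_model_context(ctx):
    """Install ctx while a model is being constructed, then restore."""
    global _current_ctx
    prev, _current_ctx = _current_ctx, ctx
    try:
        yield
    finally:
        _current_ctx = prev

def get_build_model_context():
    """What builders like build_embedding call to read the active flags."""
    return _current_ctx

with build_model_context(BuildModelContext(enforce_fp32_head=True,
                                           tie_word_embeddings=True)):
    assert get_build_model_context().tie_word_embeddings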
@@ -126,7 +126,11 @@ def build_embedding(vocab_size: int,
     """Build embedding."""
     bm_ctx = get_build_model_context()
 
-    force_dtype = torch.float32 if bm_ctx.enforce_fp32_head else None
+    # run in fp32 only when the weights are shared with lm_head
+    force_dtype = None
+    if bm_ctx.enforce_fp32_head and bm_ctx.tie_word_embeddings:
+        force_dtype = torch.float32
+
     return ParallelEmbedding(
         vocab_size,
         hidden_size,
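The gate in `build_embedding` exists because tying makes the embedding and the lm_head one tensor: forcing the head to fp32 only needs to drag the embedding to fp32 when the weights are shared. A minimal PyTorch sketch of that coupling, with plain `nn` modules standing in for `ParallelEmbedding` and illustrative sizes:

import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 100
embedding = nn.Embedding(vocab_size, hidden_size, dtype=torch.float32)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False, dtype=torch.float32)

# tie_word_embeddings=True: the head reuses the embedding Parameter,
# so the two modules literally share one storage.
lm_head.weight = embedding.weight
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()

# With tied weights, an fp32 lm_head therefore requires an fp32 embedding,
# which is what the gated force_dtype above arranges. Untied models can keep
# the embedding in the compute dtype and promote only the head.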