Commit 3386bd3

emb to fp32 only when share with lm_head
1 parent 5d03573 commit 3386bd3

File tree

4 files changed: +18 / -6 lines

lmdeploy/pytorch/config.py
lmdeploy/pytorch/engine/model_agent/agent.py
lmdeploy/pytorch/model_inputs.py
lmdeploy/pytorch/models/utils/model.py

lmdeploy/pytorch/config.py

Lines changed: 4 additions & 0 deletions
@@ -306,6 +306,7 @@ class ModelConfig:

    # fp32 lm head
    enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False

    def get_head_size(self):
        """Get head size."""

@@ -357,7 +358,10 @@ def from_pretrained(
        enforce_fp32_head = hf_overrides.pop('enforce_fp32_head', False)
        override_hf_config(model_config.hf_config, hf_overrides)

+        # for fp32 head
        model_config.enforce_fp32_head = enforce_fp32_head
+        model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
+
        # for serialization of transformers modules
        maybe_register_config_serialize_by_value(trust_remote_code)
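
For context, the lookup in from_pretrained above is a plain getattr with a False fallback, so models whose Hugging Face config omits tie_word_embeddings keep the old behaviour. A minimal, self-contained sketch of that pattern; hf_config_tied and hf_config_plain are hypothetical stand-ins, not lmdeploy or transformers objects:

from types import SimpleNamespace

# Hypothetical stand-ins for a transformers-style config object; real ones
# would come from transformers.AutoConfig.from_pretrained(...).
hf_config_tied = SimpleNamespace(tie_word_embeddings=True)
hf_config_plain = SimpleNamespace()  # some configs omit the field entirely

# Same defaulting pattern as the diff: a missing field resolves to False.
print(getattr(hf_config_tied, 'tie_word_embeddings', False))   # True
print(getattr(hf_config_plain, 'tie_word_embeddings', False))  # False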

lmdeploy/pytorch/engine/model_agent/agent.py

Lines changed: 8 additions & 5 deletions
@@ -1040,11 +1040,14 @@ def _build_model(self):
        # for router replay
        enable_return_routed_experts = self.misc_config.enable_return_routed_experts and self.need_output

-        build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder,
-                                            dllm_config=self.misc_config.dllm_config,
-                                            strategy_factory=self.strategy_factory,
-                                            enable_return_routed_experts=enable_return_routed_experts,
-                                            enforce_fp32_head=self.model_config.enforce_fp32_head)
+        build_model_ctx = BuildModelContext(
+            disable_vision_encoder=self.misc_config.disable_vision_encoder,
+            dllm_config=self.misc_config.dllm_config,
+            strategy_factory=self.strategy_factory,
+            enable_return_routed_experts=enable_return_routed_experts,
+            enforce_fp32_head=self.model_config.enforce_fp32_head,
+            tie_word_embeddings=self.model_config.tie_word_embeddings,
+        )
        patched_model = build_patched_model(self.model_config,
                                            device=device,
                                            model_format=self.misc_config.model_format,
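
The call site above now passes one keyword per line with a trailing comma, so adding a further flag later stays a one-line change. As a rough illustration of the threading, not the lmdeploy classes themselves, the sketch below uses hypothetical ToyModelConfig and ToyBuildContext dataclasses to show the new flag being copied from the model config into the context handed to the model builder:

from dataclasses import dataclass

# Hypothetical, trimmed-down stand-ins for ModelConfig and BuildModelContext.
@dataclass
class ToyModelConfig:
    enforce_fp32_head: bool = False
    tie_word_embeddings: bool = False

@dataclass
class ToyBuildContext:
    enforce_fp32_head: bool = False
    tie_word_embeddings: bool = False

model_config = ToyModelConfig(enforce_fp32_head=True, tie_word_embeddings=True)
build_model_ctx = ToyBuildContext(
    enforce_fp32_head=model_config.enforce_fp32_head,
    tie_word_embeddings=model_config.tie_word_embeddings,  # new flag threaded through
)
print(build_model_ctx)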

lmdeploy/pytorch/model_inputs.py

Lines changed: 1 addition & 0 deletions
@@ -390,6 +390,7 @@ class BuildModelContext:
    strategy_factory: 'StrategyFactoryBase' = None
    enable_return_routed_experts: bool = False
    enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False


class StepContextManager(CtxMgrBase[StepContext]):
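
Because the new BuildModelContext field carries a default of False, any existing call site that does not pass it keeps its previous behaviour. A short sketch with a hypothetical, reduced context class:

from dataclasses import dataclass

# Hypothetical, reduced version of BuildModelContext, for illustration only.
@dataclass
class ContextSketch:
    enforce_fp32_head: bool = False
    tie_word_embeddings: bool = False  # new field; default keeps old call sites valid

print(ContextSketch())                          # both flags default to False
print(ContextSketch(tie_word_embeddings=True))  # callers opt in explicitly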

lmdeploy/pytorch/models/utils/model.py

Lines changed: 5 additions & 1 deletion
@@ -126,7 +126,11 @@ def build_embedding(vocab_size: int,
    """Build embedding."""
    bm_ctx = get_build_model_context()

-    force_dtype = torch.float32 if bm_ctx.enforce_fp32_head else None
+    # run with fp32 only when share weights with lm_head
+    force_dtype = None
+    if bm_ctx.enforce_fp32_head and bm_ctx.tie_word_embeddings:
+        force_dtype = torch.float32
+
    return ParallelEmbedding(
        vocab_size,
        hidden_size,
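
The guard above is the core of the commit: with tie_word_embeddings set, the lm_head projection reuses the embedding matrix, so enforcing an fp32 head requires an fp32 embedding only in the tied case, while an untied embedding can stay in the model dtype. A hedged sketch of that intent, where pick_embedding_dtype is a hypothetical helper rather than the ParallelEmbedding code:

import torch
from torch import nn

def pick_embedding_dtype(enforce_fp32_head: bool,
                         tie_word_embeddings: bool,
                         model_dtype: torch.dtype = torch.bfloat16) -> torch.dtype:
    # Mirror of the new logic: promote the embedding to fp32 only when it is
    # the same tensor as the lm_head weight.
    if enforce_fp32_head and tie_word_embeddings:
        return torch.float32
    return model_dtype

emb = nn.Embedding(8, 4, dtype=pick_embedding_dtype(True, True))
lm_head = nn.Linear(4, 8, bias=False)
lm_head.weight = emb.weight          # weight tying: one tensor backs both modules
assert lm_head.weight is emb.weight  # the fp32 head comes straight from the embedding
assert lm_head.weight.dtype == torch.float32

print(pick_embedding_dtype(True, False))   # torch.bfloat16: untied embedding stays in model dtype
print(pick_embedding_dtype(False, True))   # torch.bfloat16: fp32 head was not requested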
