Skip to content

Commit b82939d

Browse files
authored
fix save tensor dtype (#2642)
Co-authored-by: llbdyiu66 <[email protected]>
1 parent 0a22611 commit b82939d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

paddleformers/mergekit/merge_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def shard_lora_merge(self, base_index, shard_file, lora_config, file_type_list,
590590
tensor = paddle.Tensor.__call__(tensor, zero_copy=True)
591591
lora_A_tensor = paddle.Tensor.__call__(lora_A_tensor, zero_copy=True)
592592
lora_B_tensor = paddle.Tensor.__call__(lora_B_tensor, zero_copy=True)
593-
if self.is_cpu and is_bf16 or self.merge_config.save_to_hf:
593+
if self.is_cpu and is_bf16:
594594
tensor = tensor.astype("float32")
595595
lora_A_tensor = lora_A_tensor.astype("float32")
596596
lora_B_tensor = lora_B_tensor.astype("float32")

paddleformers/transformers/model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def load_state_dict(
598598
def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
599599
for k in list(state_dict.keys()):
600600
if isinstance(state_dict[k], paddle.Tensor):
601-
if save_to_hf:
601+
if state_dict[k].dtype == paddle.bfloat16:
602602
state_dict[k] = state_dict.pop(k).astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
603603
else:
604604
state_dict[k] = state_dict.pop(k).cpu().numpy()

0 commit comments

Comments
 (0)