Skip to content

Commit 08d4ad2

Browse files
authored
fix: LoRA safetensors should use the HF PEFT-supported format (#166)
* fix: lora checkpoint
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: lora checkpoint
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: lora checkpoint
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: lora checkpoint
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: lora checkpoint
  Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent d36f3b0 commit 08d4ad2

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,25 +646,41 @@ def recover_safetensors_from_dcp(
646646
# create switch based on state dict for future use
647647
new_state_dict = {}
648648
lora = False
649+
lora_keys = {}
649650
for name, param in state_dict.items():
650651
# if lora weight, set lora switch to true
651652
if "lora_A" in name or "lora_B" in name:
652653
lora = True
653654
# if lora naming convention, convert to traditional
654655
if "base_model.model." in name:
656+
v = name
655657
name = name.replace("base_model.model.", "", 1)
658+
if "default." in name:
659+
name = name.replace("default.", "", 1)
660+
k = name
661+
lora_keys[k] = v
656662
if "default." in name:
663+
v = name
657664
name = name.replace("default.", "", 1)
665+
k = name
666+
lora_keys[k] = v
658667
new_state_dict[name] = param
659668

660669
# recover the original state dict
661670
state_dict = recover_original_state_dict_from_checkpoint(
662671
new_state_dict, _name_or_path
663672
)
664673

674+
new_state_dict = {}
675+
# modify the state dict back to HF PEFT format
676+
for name, param in state_dict.items():
677+
if lora_keys.get(name, None):
678+
name = lora_keys[name]
679+
new_state_dict[name] = param
680+
665681
# save it as a safetensors file
666682
save_sharded_safetensors(
667-
{k: v.contiguous() for k, v in state_dict.items()},
683+
{k: v.contiguous() for k, v in new_state_dict.items()},
668684
output_dir,
669685
metadata={"format": "pt"},
670686
lora=lora,

0 commit comments

Comments
 (0)