Skip to content

Commit a9c56b2

Browse files
committed
[fix] lora weight
1 parent 048d84a commit a9c56b2

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

model/model_lora.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,7 @@ def forward_with_lora(x, layer1=original_forward, layer2=lora):
3434

3535
def load_lora(model, path):
3636
state_dict = torch.load(path, map_location=model.device)
37-
38-
# 兼容DDP训练保存的权重(带有module.前缀),去除前缀以匹配当前模型
39-
new_state_dict = {}
40-
for k, v in state_dict.items():
41-
if k.startswith('module.'):
42-
new_state_dict[k[7:]] = v
43-
else:
44-
new_state_dict[k] = v
45-
state_dict = new_state_dict
37+
state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}
4638

4739
for name, module in model.named_modules():
4840
if hasattr(module, 'lora'):
@@ -54,6 +46,7 @@ def save_lora(model, path):
5446
state_dict = {}
5547
for name, module in model.named_modules():
5648
if hasattr(module, 'lora'):
57-
lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
49+
clean_name = name[7:] if name.startswith("module.") else name
50+
lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
5851
state_dict.update(lora_state)
5952
torch.save(state_dict, path)

0 commit comments

Comments
 (0)