Skip to content

Commit 048d84a

Browse files
authored
Merge pull request #594 from whiteswordLI/fix/lora-load-ddp-weights
Fix: support loading DDP-saved LoRA weights for inference
2 parents fe24501 + 3a18fdd commit 048d84a

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

model/model_lora.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def forward_with_lora(x, layer1=original_forward, layer2=lora):
3434

3535
def load_lora(model, path):
3636
state_dict = torch.load(path, map_location=model.device)
37+
38+
# 兼容DDP训练保存的权重(带有module.前缀),去除前缀以匹配当前模型
39+
new_state_dict = {}
40+
for k, v in state_dict.items():
41+
if k.startswith('module.'):
42+
new_state_dict[k[7:]] = v
43+
else:
44+
new_state_dict[k] = v
45+
state_dict = new_state_dict
46+
3747
for name, module in model.named_modules():
3848
if hasattr(module, 'lora'):
3949
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}

0 commit comments

Comments
 (0)