Commit 6ae1e1c

feat: handle lora A and B for converting checkpoint
Signed-off-by: Will Johnson <[email protected]>
1 parent 25e9155 commit 6ae1e1c

File tree

1 file changed: +14 -0 lines changed


plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 14 additions & 0 deletions
@@ -461,6 +461,20 @@ def _infer_prefixes_and_module_names(

     if len(scatter_keys) == 1:
         sd[model_key] = scatter_params[scatter_keys[0]]
+
+    elif any("lora_A" in k for k in scatter_keys) and any("lora_B" in k for k in scatter_keys):
+        lora_A_key = next((k for k in scatter_keys if "lora_A" in k), None)
+        lora_B_key = next((k for k in scatter_keys if "lora_B" in k), None)
+
+        if lora_A_key and lora_B_key:
+            lora_A = scatter_params[lora_A_key]
+            lora_B = scatter_params[lora_B_key]
+
+            # Multiply matrices
+            lora_weight = torch.matmul(lora_B, lora_A)
+
+            sd[model_key] = lora_weight
+
     else:
         # unfortunately, there this is a in
         # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
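
For context: the new branch collapses a LoRA adapter pair into a single dense tensor before it is written to the converted checkpoint. With lora_A of shape (r, in_features) and lora_B of shape (out_features, r), torch.matmul(lora_B, lora_A) yields an (out_features, in_features) matrix that can be stored under the module's weight key. A minimal standalone sketch of that shape arithmetic (hypothetical dimensions, not taken from the repository):

import torch

# Hypothetical LoRA pair for one projection, rank r = 8.
in_features, out_features, r = 1024, 4096, 8
lora_A = torch.randn(r, in_features)    # would come from scatter_params[lora_A_key]
lora_B = torch.randn(out_features, r)   # would come from scatter_params[lora_B_key]

# Same operation as the added branch: B @ A gives the full-shape weight delta.
lora_weight = torch.matmul(lora_B, lora_A)
assert lora_weight.shape == (out_features, in_features)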
