
Commit 7b453cd

fix: lora constants
Signed-off-by: Will Johnson <[email protected]>
1 parent f99ae71 commit 7b453cd

File tree

3 files changed: +7 −3 lines changed


plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 1 addition & 1 deletion

@@ -344,7 +344,7 @@ def _infer_prefixes_and_module_names(
 ):
     _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
     # pylint: disable=anomalous-backslash-in-string
-    _reg = re.compile(f"(.*)\.({_name})\.weight")
+    _reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A\.weight|lora_B\.weight)")
     found = {}

     for k in sd_keys:
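For context, a minimal standalone sketch (not part of the commit) of what the widened pattern now matches. The module names "router", "w1", "w2" are hypothetical stand-ins for PARAM_NAME_ROUTER_SCATTERMOE and PARAM_NAME_WEIGHT_SCATTERMOE, and the key strings are illustrative:

import re

# Hypothetical stand-ins for PARAM_NAME_ROUTER_SCATTERMOE /
# PARAM_NAME_WEIGHT_SCATTERMOE.
_name = "|".join(["router", "w1", "w2"])
_reg = re.compile(rf"(.*)\.({_name})\.(?:weight|lora_A\.weight|lora_B\.weight)")

for key in [
    "model.layers.0.moe.router.weight",         # matched before and after the fix
    "model.layers.0.moe.router.lora_A.weight",  # matched only after the fix
    "model.layers.0.moe.w1.lora_B.weight",      # matched only after the fix
    "model.layers.0.moe.w1.bias",               # never matched
]:
    m = _reg.match(key)
    print(key, "->", m.groups() if m else None)

Judging from the surrounding hunk, the captured (prefix, module name) groups are what _infer_prefixes_and_module_names collects into found, so LoRA adapter weights on the router and expert projections are now discovered alongside the base weights.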

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py

Lines changed: 3 additions & 1 deletion

@@ -24,7 +24,9 @@
 KEY_EXPERT_PARALLEL = "expert_parallel"
 DIM_EXPERT = 0

-KEY_SCATTERMOE_ROUTER = PARAM_NAME_ROUTER_SCATTERMOE + ".weight"
+KEY_SCATTERMOE_ROUTER = "router.weight"
+KEY_SCATTERMOE_LORA_A_ROUTER = "router.lora_A.weight"
+KEY_SCATTERMOE_LORA_B_ROUTER = "router.lora_B.weight"

 # Currently our ScatterMoE drop supports an up/down proj and
 # an optional gate_proj.
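With the fix, the three constants spell out the full key suffixes directly (the lora_A/lora_B suffixes follow PEFT-style LoRA naming). A hedged sketch of how such constants can classify state-dict keys; the helper and key names below are illustrative, not from the repo:

KEY_SCATTERMOE_ROUTER = "router.weight"
KEY_SCATTERMOE_LORA_A_ROUTER = "router.lora_A.weight"
KEY_SCATTERMOE_LORA_B_ROUTER = "router.lora_B.weight"

_ROUTER_KEYS = (
    KEY_SCATTERMOE_ROUTER,
    KEY_SCATTERMOE_LORA_A_ROUTER,
    KEY_SCATTERMOE_LORA_B_ROUTER,
)

def is_router_param(weight_name: str) -> bool:
    # Substring test, mirroring the `in weight_name` checks used in
    # scattermoe_prepare.py below.
    return any(k in weight_name for k in _ROUTER_KEYS)

assert is_router_param("model.layers.3.moe.router.lora_A.weight")
assert not is_router_param("model.layers.3.moe.w1.weight")

Note that "router.weight" is not a substring of "router.lora_A.weight", so all three constants are needed to cover the base and adapter keys.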

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 3 additions & 1 deletion

@@ -34,6 +34,8 @@
     KEY_EXPERT_PARALLEL,
     KEY_REPLICATE,
     KEY_SCATTERMOE_ROUTER,
+    KEY_SCATTERMOE_LORA_A_ROUTER,
+    KEY_SCATTERMOE_LORA_B_ROUTER,
     get_scattermoe_conv_spec_from_archs,
 )
 from .scattermoe_state_dict import (

@@ -66,7 +68,7 @@ def _hook(grad):

     for weight_name, param in state_dict.items():

-        if KEY_SCATTERMOE_ROUTER in weight_name:
+        if KEY_SCATTERMOE_ROUTER in weight_name or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name:
             # if it's the router, replicate
             param = distribute_tensor(param, device_mesh, reps + [Replicate()])
         elif param.shape[0] > num_experts_per_device:
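The widened condition sends base and LoRA router tensors down the same replicate branch, while large expert-stacked weights continue to be sharded. A simplified, name-only sketch of that dispatch (the real code calls torch's distribute_tensor with Replicate()/Shard placements; the threshold check follows the elif above):

KEY_SCATTERMOE_ROUTER = "router.weight"
KEY_SCATTERMOE_LORA_A_ROUTER = "router.lora_A.weight"
KEY_SCATTERMOE_LORA_B_ROUTER = "router.lora_B.weight"

def placement_for(weight_name: str, dim0: int, num_experts_per_device: int) -> str:
    """Decide how a ScatterMoE parameter is laid out across the device mesh."""
    if (
        KEY_SCATTERMOE_ROUTER in weight_name
        or KEY_SCATTERMOE_LORA_A_ROUTER in weight_name
        or KEY_SCATTERMOE_LORA_B_ROUTER in weight_name
    ):
        # Routers (and their LoRA adapters) are small: copy to every device.
        return "replicate"
    if dim0 > num_experts_per_device:
        # Expert-stacked weights: split along the expert dimension.
        return "shard"
    return "keep-local"

print(placement_for("moe.router.lora_B.weight", 8, 2))  # replicate
print(placement_for("moe.w1.weight", 8, 2))             # shard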
