Skip to content

Commit d98b2c9

Browse files
committed
fix: logic for lora
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent f9176c5 commit d98b2c9

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,19 @@ def recover_original_state_dict_from_checkpoint(
343343
# config
344344
config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
345345

346+
# if lora, check whether input/output (expert) layers and a router layer are present
347+
ip_op_layers = False
348+
router_layer = False
349+
if lora:
350+
for name, _ in sd.items():
351+
if "w1" in name:
352+
ip_op_layers = True
353+
break
354+
for name, _ in sd.items():
355+
if "router" in name:
356+
router_layer = True
357+
break
358+
346359
(
347360
_,
348361
router_name,
@@ -412,7 +425,8 @@ def _infer_prefixes_and_module_names(
412425
module_name,
413426
router_name,
414427
expert_name,
415-
lora_utils=lora,
428+
ip_op_layers=ip_op_layers,
429+
router_layer=router_layer,
416430
)
417431

418432
model2scatter = defaultdict(dict)

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,6 @@ def prepare_scattermoe(
128128
# pylint: disable=import-outside-toplevel
129129
from .scattermoe import ScatterMoE
130130

131-
lora = False
132-
if lora_config:
133-
lora = True
134-
135131
if disable_distributed and ep_degree > 1:
136132
raise ValueError(
137133
"expert sharding can not be deferred to top level sharding"
@@ -255,7 +251,6 @@ def prepare_scattermoe(
255251
module_name,
256252
router_name,
257253
"|".join(expert_name),
258-
lora_start=lora,
259254
target_modules=lora_config.target_modules,
260255
)
261256

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def get_checkpoint_meta_from_sharded_safetensor(
8888
router_name: str = "gate", # e.g., named "gate" within block_sparse_moe
8989
expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
9090
expert_map: Dict = None, # map -> [w1,w2,w3]
91-
lora_start: bool = False, # if lora is detected in prepare_scattermoe.py
92-
lora_utils: bool = False, # if lora is detected in checkpoint_utils.py
91+
ip_op_layers: bool = False, # if input/output layers are detected in utils
92+
router_layer: bool = False, # if router layer is detected in utils
9393
target_modules: Dict = None, # target modules from prepare_scattermoe.py
9494
) -> Dict[str, List[Tuple]]:
9595
"""
@@ -111,6 +111,8 @@ def get_checkpoint_meta_from_sharded_safetensor(
111111
e.g., input_linear|output_linear|input_linear
112112
expert_map (dict): This is used with pattern ii) described above in expert_name.
113113
If not specified, will be the identity map, e.g., w1 -> w1
114+
ip_op_layers (bool): True if LoRA input/output (expert) layers were detected in the checkpoint state dict.
115+
router_layer (bool): True if a LoRA router layer was detected in the checkpoint state dict.
114116
"""
115117

116118
# insert in order
@@ -171,34 +173,26 @@ def _insert(L: List, i: int, v):
171173
f"'{router_name}' or expert_name '{expert_name}'"
172174
)
173175
if m.group(1) == router_name:
174-
if lora_utils:
176+
if router_layer:
175177
_map[KEY_SCATTERMOE_LORA_A_ROUTER].append((k, stfile))
176178
_map[KEY_SCATTERMOE_LORA_B_ROUTER].append((k, stfile))
177179
else:
178180
_map[KEY_SCATTERMOE_ROUTER].append((k, stfile))
179181
elif m.group(1) in expert_name:
180-
index = m.group(2)
181-
index = 0 if index is None else int(index)
182-
mod = None
183-
184-
# LoRA case
185182
if (
186183
"input_linear" in target_modules and "output_linear" in target_modules
187-
) or lora_utils:
188-
if not lora_utils:
184+
) or ip_op_layers:
185+
index = m.group(2)
186+
index = 0 if index is None else int(index)
187+
mod = None
188+
if not ip_op_layers:
189189
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
190190
_insert(_map[f"{mod}.weight"], index, (k, stfile))
191191
else:
192192
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
193193
_insert(_map[f"{mod}.lora_A"], index, (k, stfile))
194194
_insert(_map[f"{mod}.lora_B"], index, (k, stfile))
195-
196-
# Fine-tuning case
197-
elif not lora_utils and not lora_start:
198-
for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
199-
_insert(_map[f"{mod}.weight"], index, (k, stfile))
200-
201-
assert mod is not None, f"cannot map '{rel_k}'"
195+
assert mod is not None, f"cannot map '{rel_k}'"
202196

203197
if len(_map) == 0:
204198
raise ValueError(

0 commit comments

Comments
 (0)