@@ -88,8 +88,8 @@ def get_checkpoint_meta_from_sharded_safetensor(
8888 router_name : str = "gate" , # e.g., named "gate" within block_sparse_moe
8989 expert_name : str = "experts" , # e.g., named "experts" within block_sparse_moe
9090 expert_map : Dict = None , # map -> [w1,w2,w3]
91- lora_start : bool = False , # if lora is detected in prepare_scattermoe.py
92- lora_utils : bool = False , # if lora is detected in checkpoint_utils.py
91+ ip_op_layers : bool = False , # if input/output layers are detected in utils
92+ router_layer : bool = False , # if router layer is detected in utils
9393 target_modules : Dict = None , # target modules from prepare_scattermoe.py
9494) -> Dict [str , List [Tuple ]]:
9595 """
@@ -111,6 +111,8 @@ def get_checkpoint_meta_from_sharded_safetensor(
111111 e.g., input_linear|output_linear|input_linear
112112 expert_map (dict): This is used with pattern ii) described above in expert_name.
113113 If not specified, will be the identity map, e.g., w1 -> w1
114+         ip_op_layers (bool): True if input/output expert layers are detected in utils;
115+         router_layer (bool): True if a LoRA router layer is detected in utils.
114116 """
115117
116118 # insert in order
@@ -171,34 +173,26 @@ def _insert(L: List, i: int, v):
171173 f"'{ router_name } ' or expert_name '{ expert_name } '"
172174 )
173175 if m .group (1 ) == router_name :
174- if lora_utils :
176+ if router_layer :
175177 _map [KEY_SCATTERMOE_LORA_A_ROUTER ].append ((k , stfile ))
176178 _map [KEY_SCATTERMOE_LORA_B_ROUTER ].append ((k , stfile ))
177179 else :
178180 _map [KEY_SCATTERMOE_ROUTER ].append ((k , stfile ))
179181 elif m .group (1 ) in expert_name :
180- index = m .group (2 )
181- index = 0 if index is None else int (index )
182- mod = None
183-
184- # LoRA case
185182 if (
186183 "input_linear" in target_modules and "output_linear" in target_modules
187- ) or lora_utils :
188- if not lora_utils :
184+ ) or ip_op_layers :
185+ index = m .group (2 )
186+ index = 0 if index is None else int (index )
187+ mod = None
188+ if not ip_op_layers :
189189 for mod in expert_map .get (m .group (1 ), expert_map .get (m .group (3 ))):
190190 _insert (_map [f"{ mod } .weight" ], index , (k , stfile ))
191191 else :
192192 for mod in expert_map .get (m .group (1 ), expert_map .get (m .group (3 ))):
193193 _insert (_map [f"{ mod } .lora_A" ], index , (k , stfile ))
194194 _insert (_map [f"{ mod } .lora_B" ], index , (k , stfile ))
195-
196- # Fine-tuning case
197- elif not lora_utils and not lora_start :
198- for mod in expert_map .get (m .group (1 ), expert_map .get (m .group (3 ))):
199- _insert (_map [f"{ mod } .weight" ], index , (k , stfile ))
200-
201- assert mod is not None , f"cannot map '{ rel_k } '"
195+ assert mod is not None , f"cannot map '{ rel_k } '"
202196
203197 if len (_map ) == 0 :
204198 raise ValueError (
0 commit comments