 # limitations under the License.

 # Standard
+from collections import defaultdict
+from typing import List
+import json
 import os
 import re
-import json
-import torch
-
-from .scattermoe_constants import (
-    FILE_SAFETENSOR_INDEX,
-    PARAM_NAME_WEIGHT_SCATTERMOE,
-    PARAM_NAME_ROUTER_SCATTERMOE,
-    get_scattermoe_conv_spec_from_archs,
-)
-
-from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor

 # Third Party
-from transformers import PretrainedConfig
 from accelerate.logging import get_logger
 from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
 from torch.distributed.checkpoint.default_planner import (
 )
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from transformers import PretrainedConfig
+import torch
 import torch.distributed.checkpoint as dcp

-from typing import List
+# Local
+from .scattermoe_constants import (
+    FILE_SAFETENSOR_INDEX,
+    PARAM_NAME_ROUTER_SCATTERMOE,
+    PARAM_NAME_WEIGHT_SCATTERMOE,
+    get_scattermoe_conv_spec_from_archs,
+)
+from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor

 logger = get_logger(__name__)

@@ -181,6 +181,7 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)

+
 # trick to get the resolved cache file to access the safetensor
 # NOTE: this does not work if the config is not read via _dict_from_json_file, e.g. GGUF files
 def get_resolved_checkpoint_location(model_name_or_path: str):
@@ -199,26 +200,29 @@ def _dict_from_json_file(resolved_config_file):
     PretrainedConfig._dict_from_json_file = _old_func
     return os.path.dirname(result)

+
 def restore_scattermoe_checkpoint_to_orig(
     dcp_checkpoint_dir: str,
     pretrained_model_name_or_path: str = None,
-    dcp_outer_key: str = 'model',
+    dcp_outer_key: str = "model",
 ):
-
     """
     Parameters:
         dcp_checkpoint_dir (str): the dcp to be converted.
-        pretrained_model_name_or_path (str): Optional, if provided we will
-            use the hints to remap the
+        pretrained_model_name_or_path (str): Optional, if provided we will
+            use the hints to remap the state dict back to the original checkpoint format.
     """

     # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
     # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

     # guarded, load some internal functions
+    # pylint: disable=import-outside-toplevel
+    # Third Party
     from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
-    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
     sd: STATE_DICT_TYPE = {}
     _load_state_dict(
         sd,
@@ -228,7 +232,7 @@ def restore_scattermoe_checkpoint_to_orig(
     )
     sd = sd[dcp_outer_key]

-    # if not provided
+    # if not provided, just return the unmapped state dict
     if pretrained_model_name_or_path is None:
         return sd

@@ -237,18 +241,16 @@ def restore_scattermoe_checkpoint_to_orig(
     with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
         index = json.load(f)

-    # config
+    # config
     config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)

     (
-        moe_cls,
+        _,
         router_name,
         expert_name,
-        expert_mlp_spec,
+        __,
         sharded_expert_ckpt,
-    ) = get_scattermoe_conv_spec_from_archs(
-        config.architectures
-    )
+    ) = get_scattermoe_conv_spec_from_archs(config.architectures)

     # the sd from the module swap must have keys like
     # 'model.layers.0.block_sparse_moe.w1.weight'
@@ -258,13 +260,12 @@ def restore_scattermoe_checkpoint_to_orig(
     # prefix = model.layers.0 and module_name = block_sparse_moe

     def _infer_prefixes_and_module_names(
-        sd_keys: List[str], min_count: int = 3,
+        sd_keys: List[str],
+        min_count: int = 3,
     ):
-        _name = "|".join([
-            PARAM_NAME_ROUTER_SCATTERMOE,
-            *PARAM_NAME_WEIGHT_SCATTERMOE
-        ])
-        _reg = re.compile(f'(.*)\.({_name})\.weight')
+        _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
+        # pylint: disable=anomalous-backslash-in-string
+        _reg = re.compile(f"(.*)\.({_name})\.weight")
         found = {}

         for k in sd_keys:
@@ -276,22 +277,22 @@ def _infer_prefixes_and_module_names(
             found[prefix] = 1 + found.get(prefix, 0)

         results = []
-        for prefix in found.keys():
+        for prefix, cnt in found.items():
             # if at least router, w1 and w2 are found, take it
             # otherwise we skip the prefix
-            if found[prefix] >= min_count:
+            if cnt >= min_count:
                 results.append(prefix)

         return results

     for prefix in _infer_prefixes_and_module_names(sd.keys()):
-        prefix = prefix.split('.')
+        prefix = prefix.split(".")
         prefix, module_name = ".".join(prefix[:-1]), prefix[-1]

         # checkpoint metadata will be a map
         # key -> list of tuples
         # where each entry in the list is (param_name, stfile)
-        # - if the list is larger than one, it means that the
+        # - if the list is larger than one, it means that the
         #   actual model has a sharded checkpoint

         # defaultdict(list,
@@ -312,22 +313,12 @@ def _infer_prefixes_and_module_names(
         expert_name,
     )

-    # check if expert name has repeats
-    # has_repeat = False
-    # _seen = set()
-    # for k in expert_name.split('|'):
-    #     if k in _seen:
-    #         has_repeat = True
-    #         break
-    #     _seen.add(k)
-
-    from collections import defaultdict
     model2scatter = defaultdict(dict)
     # construct a map of model_key -> {scatter_key: [params, ...]}
-    # - if the param list > 1, that means many scatter keys map to 1
+    # - if the param list > 1, that means many scatter keys map to 1
     #   model param and they need to be cat
     for scatter_key, list_of_params in checkpoint_metadata.items():
-        scatter_key_fqdn = '.'.join([prefix, module_name, scatter_key])
+        scatter_key_fqdn = ".".join([prefix, module_name, scatter_key])
         scatter_param = sd[scatter_key_fqdn]

         # remove from state dict
@@ -341,13 +332,11 @@ def _infer_prefixes_and_module_names(
         else:
             # if sharded, we just assume that there should be 1 expert
             # per shard
-            assert n == scatter_param.shape[0], \
-                "Sharded expert weights should be 1 expert per shard."
+            assert (
+                n == scatter_param.shape[0]
+            ), "Sharded expert weights should be 1 expert per shard."

-        if any(
-            scatter_key.startswith(k) for k in
-            PARAM_NAME_WEIGHT_SCATTERMOE
-        ):
+        if any(scatter_key.startswith(k) for k in PARAM_NAME_WEIGHT_SCATTERMOE):
             scatter_param = scatter_param.permute(0, 2, 1)

         # go through all the model keys
@@ -363,33 +352,33 @@ def _infer_prefixes_and_module_names(
             model2scatter[model_key][scatter_key] = _param

     # replace them back in the sd
-    for model_key in list(model2scatter.keys()):
+    for model_key in list(model2scatter.keys()):

         scatter_params = model2scatter[model_key]

         # - there is an assumption that if there is a cat, then
         #   it will go by order of scatter keys
         scatter_keys = sorted(scatter_params.keys())

-        assert len(scatter_keys) > 0, f"Obtained zero scatter keys for model_key \'{model_key}\'"
+        assert (
+            len(scatter_keys) > 0
+        ), f"Obtained zero scatter keys for model_key '{model_key}'"

         if len(scatter_keys) == 1:
             sd[model_key] = scatter_params[scatter_keys[0]]
         else:
-            # unfortunately, there this is a in
+            # unfortunately, there is a reshape in
             # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
             # that splits on dim=1, so we cat back on that dim
             sd[model_key] = torch.cat(
-                [scatter_params[k] for k in scatter_keys],
-                dim=1
+                [scatter_params[k] for k in scatter_keys], dim=1
             )

         # remove from this intermediate mapping
         del model2scatter[model_key]

-    rem_keys = ",".join([k for k in model2scatter])
-    assert len(rem_keys) == 0, \
-        f"Did not handle model parameters \'{rem_keys}\'"
+    rem_keys = ",".join(list(model2scatter))
+    assert len(rem_keys) == 0, f"Did not handle model parameters '{rem_keys}'"

     return sd

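To make the prefix inference above concrete, here is a small self-contained sketch of the same idea; the key strings and the router/weight names below are made up for illustration (the real values come from scattermoe_constants):

# Standard
import re

# placeholder router / expert-weight names; the real values come from scattermoe_constants
router_and_weight_names = ["router", "w1", "w2", "w3"]
_reg = re.compile(rf"(.*)\.({'|'.join(router_and_weight_names)})\.weight")

keys = [
    "model.layers.0.block_sparse_moe.router.weight",
    "model.layers.0.block_sparse_moe.w1.weight",
    "model.layers.0.block_sparse_moe.w2.weight",
]
found = {}
for k in keys:
    m = _reg.match(k)
    if m:  # group(1) is everything before '.<router|w1|...>.weight'
        found[m.group(1)] = found.get(m.group(1), 0) + 1

# keep prefixes where at least router, w1 and w2 were seen (min_count = 3)
print([p for p, c in found.items() if c >= 3])
# -> ['model.layers.0.block_sparse_moe']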
@@ -398,7 +387,8 @@ def _infer_prefixes_and_module_names(


 # have it serve as a conversion script
-if __name__ == '__main__':
+if __name__ == "__main__":
+    # Standard
     import argparse

     parser = argparse.ArgumentParser(
@@ -417,7 +407,5 @@ def _infer_prefixes_and_module_names(
417407 "In order to reconstruct the state dict, we requre hints from "
418408 "the original pretrained model checkpoint (from which this "
419409 "checkpoint is obtained)."
420- )
410+ ),
421411 )
422-
423-
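Besides the argparse entry point above, the conversion can also be driven directly from Python; a minimal usage sketch, where the import path and file paths are placeholders rather than part of this change:

# Third Party
import torch

# Local (hypothetical import path; adjust to wherever this module lives in your tree)
from checkpoint_utils import restore_scattermoe_checkpoint_to_orig

# placeholder paths, for illustration only
sd = restore_scattermoe_checkpoint_to_orig(
    "path/to/dcp/checkpoint/dir",
    pretrained_model_name_or_path="path/to/original/pretrained/checkpoint",
)
torch.save(sd, "restored_state_dict.bin")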