File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -565,13 +565,15 @@ def set_initialized_submodules(model, state_dict_keys):
565565 Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
566566 dict.
567567 """
568+ state_dict_keys = set (state_dict_keys )
568569 not_initialized_submodules = {}
569570 for module_name , module in model .named_modules ():
570- loaded_keys = {k .replace (f"{ module_name } ." , "" ) for k in state_dict_keys if k .startswith (f"{ module_name } ." )}
571- # When checking if the root module is loaded all state_dict_keys must be used.
572571 if module_name == "" :
573- loaded_keys = set (state_dict_keys )
574- if loaded_keys .issuperset (module .state_dict ()):
572+ # When checking if the root module is loaded there's no need to prepend module_name.
573+ module_keys = set (module .state_dict ())
574+ else :
575+ module_keys = {f"{ module_name } .{ k } " for k in module .state_dict ()}
576+ if module_keys .issubset (state_dict_keys ):
575577 module ._is_hf_initialized = True
576578 else :
577579 not_initialized_submodules [module_name ] = module
You can’t perform that action at this time.
0 commit comments