Skip to content

Commit 7ad1368

Browse files
LagPixelLOLelvircrn
authored andcommitted
Optimized set_initialized_submodules. (huggingface#35493)
1 parent 03bf893 commit 7ad1368

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/transformers/modeling_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -565,13 +565,15 @@ def set_initialized_submodules(model, state_dict_keys):
565565
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
566566
dict.
567567
"""
568+
state_dict_keys = set(state_dict_keys)
568569
not_initialized_submodules = {}
569570
for module_name, module in model.named_modules():
570-
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
571-
# When checking if the root module is loaded all state_dict_keys must be used.
572571
if module_name == "":
573-
loaded_keys = set(state_dict_keys)
574-
if loaded_keys.issuperset(module.state_dict()):
572+
# When checking if the root module is loaded there's no need to prepend module_name.
573+
module_keys = set(module.state_dict())
574+
else:
575+
module_keys = {f"{module_name}.{k}" for k in module.state_dict()}
576+
if module_keys.issubset(state_dict_keys):
575577
module._is_hf_initialized = True
576578
else:
577579
not_initialized_submodules[module_name] = module

0 commit comments

Comments
 (0)