
Commit 19376dd

lint
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 027a206 commit 19376dd

File tree

1 file changed (+13 / -12 lines)

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 13 additions & 12 deletions
@@ -47,6 +47,7 @@

 if TYPE_CHECKING:
     from peft import PeftConfig
+
     from transformers.tokenization_utils_base import PreTrainedTokenizerBase


@@ -313,9 +314,7 @@ def load_model(

     # For models that need tensor merging and don't have an adapter, try using transformers' conversion
     if is_init_step and model_type and requires_tensor_merging(model_type) and not has_state_dict_adapter:
-        converted_state_dict = _convert_checkpoint_with_transformers(
-            model_state.model[0], model_path, key_mapping
-        )
+        converted_state_dict = _convert_checkpoint_with_transformers(model_state.model[0], model_path, key_mapping)
         if converted_state_dict is not None:
             # Load using full_state_dict=True to properly convert tensors to DTensors for FSDP
             _load_full_state_dict_into_model(model_state.model, converted_state_dict)
@@ -887,7 +886,8 @@ def _load_full_state_dict_into_model(
         state_dict: Full state dict with regular tensors
     """
     from functools import partial
-    from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions
+
+    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

     # Use full_state_dict=True to tell PyTorch this is a complete, non-sharded state dict
     # It will properly shard the tensors to match the model's DTensor layout
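For reference, a minimal sketch of the loading pattern this hunk touches: set_model_state_dict with StateDictOptions(full_state_dict=True) lets PyTorch take a complete, non-sharded state dict and re-shard it onto the model's DTensor layout. The helper name and toy dict below are illustrative only, not code from this repository.

    # Minimal sketch (assumes torch with the distributed checkpoint state_dict utilities available).
    import torch
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

    def load_full_dict(model: torch.nn.Module, full_sd: dict) -> None:
        # full_state_dict=True declares the dict holds complete, unsharded tensors;
        # PyTorch then shards them to match the model's (possibly DTensor/FSDP) layout.
        set_model_state_dict(model, full_sd, options=StateDictOptions(full_state_dict=True))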
@@ -925,16 +925,17 @@ def _convert_checkpoint_with_transformers(
        Converted state dict ready for loading, or None if conversion failed.
    """
    try:
+        from copy import deepcopy
+
+        from safetensors import safe_open
+
        from transformers.conversion_mapping import get_model_conversion_mapping
        from transformers.core_model_loading import (
            WeightConverter,
            WeightRenaming,
-            rename_source_key,
            dot_natural_key,
+            rename_source_key,
        )
-        from safetensors import safe_open
-        from copy import deepcopy
-        from collections import defaultdict
    except ImportError:
        logging.warning(
            "transformers library with conversion_mapping not available. "
@@ -946,7 +947,9 @@ def _convert_checkpoint_with_transformers(
        # Get the weight conversion mapping from transformers
        weight_mapping = get_model_conversion_mapping(model, key_mapping=key_mapping, add_legacy=True)
        if not weight_mapping:
-            logging.warning(f"No conversion mapping found for model type {getattr(model.config, 'model_type', 'unknown')}")
+            logging.warning(
+                f"No conversion mapping found for model type {getattr(model.config, 'model_type', 'unknown')}"
+            )
            return None

        # Load the safetensors files
@@ -962,9 +965,6 @@ def _convert_checkpoint_with_transformers(
                for key in f.keys():
                    checkpoint_state_dict[key] = f.get_tensor(key)

-        # Get model's expected keys
-        model_state_dict = model.state_dict()
-
        # Separate renamings and converters
        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
        converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
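For reference, a minimal sketch of the safetensors read shown in the context lines above, assuming a hypothetical checkpoint_dir holding *.safetensors shards (the directory layout and helper name are illustrative only, not code from this repository):

    # Minimal sketch: collect every tensor from each *.safetensors shard into one dict.
    from pathlib import Path

    import torch
    from safetensors import safe_open

    def read_safetensors_shards(checkpoint_dir: str) -> dict[str, torch.Tensor]:
        state_dict: dict[str, torch.Tensor] = {}
        for shard in sorted(Path(checkpoint_dir).glob("*.safetensors")):
            with safe_open(str(shard), framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
        return state_dict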
@@ -1010,6 +1010,7 @@ def _convert_checkpoint_with_transformers(
    except Exception as e:
        logging.warning(f"Failed to convert checkpoint with transformers: {e}")
        import traceback
+
        traceback.print_exc()
        return None

0 commit comments
