|
18 | 18 | import os |
19 | 19 | import re |
20 | 20 | from collections import defaultdict, namedtuple |
21 | | -from contextlib import contextmanager |
22 | 21 | from functools import lru_cache |
23 | 22 | from pathlib import Path |
24 | | -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union |
| 23 | +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union |
25 | 24 |
|
26 | 25 | from packaging import version |
27 | 26 |
|
@@ -538,13 +537,15 @@ def _load_sharded_checkpoint( |
538 | 537 | for shard_file in shard_files: |
539 | 538 | # Load shard into memory |
540 | 539 | shard_path = os.path.join(save_directory, shard_file) |
541 | | - with _load_shard_into_memory( |
| 540 | + state_dict = load_state_dict_from_file( |
542 | 541 | shard_path, |
543 | | - load_fn=load_state_dict_from_file, |
544 | | - kwargs={"weights_only": weights_only}, |
545 | | - ) as state_dict: |
546 | | - # Update model with parameters from this shard |
547 | | - model.load_state_dict(state_dict, strict=strict) |
| 542 | + map_location="cpu", |
| 543 | + weights_only=weights_only, |
| 544 | + ) |
| 545 | + # Update model with parameters from this shard |
| 546 | + model.load_state_dict(state_dict, strict=strict) |
| 547 | + # Explicitly remove the state dict from memory |
| 548 | + del state_dict |
548 | 549 |
|
549 | 550 | # 4. Return compatibility info |
550 | 551 | loaded_keys = set(index["weight_map"].keys()) |
@@ -630,7 +631,8 @@ def load_state_dict_from_file( |
630 | 631 | # Check format of the archive |
631 | 632 | with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined] |
632 | 633 | metadata = f.metadata() |
633 | | - if metadata.get("format") != "pt": |
| 634 | + # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966 |
| 635 | + if metadata is not None and metadata.get("format") not in ["pt", "mlx"]: |
634 | 636 | raise OSError( |
635 | 637 | f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " |
636 | 638 | "you save your model with the `save_torch_model` method." |
@@ -668,30 +670,6 @@ def load_state_dict_from_file( |
668 | 670 | # HELPERS |
669 | 671 |
|
670 | 672 |
|
671 | | -@contextmanager |
672 | | -def _load_shard_into_memory( |
673 | | - shard_path: str, |
674 | | - load_fn: Callable, |
675 | | - kwargs: Optional[Dict[str, Any]] = None, |
676 | | -): |
677 | | - """ |
678 | | - Context manager to handle loading and cleanup of model shards. |
679 | | -
|
680 | | - Args: |
681 | | - shard_path: Path to the shard file |
682 | | - load_fn: Function to load the shard (either torch.load or safetensors.load) |
683 | | -
|
684 | | - Yields: |
685 | | - The loaded state dict for this shard |
686 | | - """ |
687 | | - try: |
688 | | - state_dict = load_fn(shard_path, **kwargs) # type: ignore[arg-type] |
689 | | - yield state_dict |
690 | | - finally: |
691 | | - # Explicitly remove the state dict from memory |
692 | | - del state_dict |
693 | | - |
694 | | - |
695 | 673 | def _validate_keys_for_strict_loading( |
696 | 674 | model: "torch.nn.Module", |
697 | 675 | loaded_keys: Iterable[str], |
|