Skip to content

Commit afb0fa8

Browse files
authored
Merge pull request #23 from grok-ai/feature/ckpt-substitute
Add option to substitute keys/values in old checkpoints
2 parents f97aa83 + ea568bf commit afb0fa8

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/nn_core/serialization.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import importlib
23
import inspect
34
import logging
@@ -17,6 +18,7 @@
1718

1819
pylogger = logging.getLogger(__name__)
1920

21+
from typing import Mapping
2022

2123
_METADATA_MODULE_KEY = f"{METADATA_KEY}_module"
2224
_METADATA_CLASS_KEY = f"{METADATA_KEY}_class"
@@ -124,14 +126,41 @@ def extract_checkpoint(ckpt_file: Path) -> Path:
124126
yield Path(tmp_dir)
125127

126128

129+
def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys: Dict[str, str] = {}):
130+
if not isinstance(dictionary, Mapping):
131+
if isinstance(dictionary, collections.Hashable):
132+
if substitute_values is not None and dictionary in substitute_values:
133+
return substitute_values[dictionary]
134+
elif substitute_keys is not None and dictionary in substitute_keys:
135+
return substitute_keys[dictionary]
136+
else:
137+
return dictionary
138+
return dictionary
139+
140+
return {
141+
_substistute(key, substitute_values=substitute_values, substitute_keys=substitute_keys,): _substistute(
142+
value,
143+
substitute_values=substitute_values,
144+
substitute_keys=substitute_keys,
145+
)
146+
for key, value in dictionary.items()
147+
}
148+
149+
127150
def load_model(
128151
module_class: Type[pl.LightningModule],
129152
checkpoint_path: Path,
130153
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
154+
substitute_keys: Optional[Dict[str, str]] = None,
155+
substitute_values: Optional[Dict[str, str]] = None,
131156
):
132157
# Lightning checkpoints end with .ckpt, ours with .ckpt.zip
133158
if checkpoint_path.name.endswith(".ckpt.zip"):
134159
checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
160+
161+
if substitute_values is not None:
162+
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
163+
135164
return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
136165
else:
137166
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")

0 commit comments

Comments
 (0)