
Commit e295dbc

Sketch dist-ckpt content versioning (#13839)

Authored by mikolajblaz, chtruong814, and dimapihtar
* Sketch dist-ckpt content versioning
* Apply isort and black reformatting
* Add docs
* Change dist_opt_sharding_type name
* Remove MappingProxyType
* Expand docs
* Fix .rst formatting
* Unindent code
* Match MLM metadata creation logic
* Remove one Nemo1 TODO
* Add doc
* Remove collections/nlp TODOs (NeMo 1)
* Fix some TODOs
* Remove f string
* Fix MegatronStrategy typo
* Handle TODOs
* Add missing import
* Add content metadata flag to async
* Add missing load_content_metadata
* Load content_metadata through unwrapped_ckpt_io
* Pass content_metadata through storage_options
* Fix linting problems
* Add sharded_state_dict_metadata to FabricMegatronStrategy
* Fix indentation
* Fix last unwrapped_checkpoint_io
* Fix None type
* Add local ckpt versioning note
* Add safe_import fix
* Revert nemo_logger changes
* Remove chained_optim_avoid_prefix flag
* Add unit tests
* Fix OptimizerWrapper init signature

Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>
Signed-off-by: mikolajblaz <mikolajblaz@users.noreply.github.com>
Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>
Co-authored-by: Charlie Truong <chtruong@nvidia.com>
Co-authored-by: dimapihtar <dpihtar@gmail.com>
Co-authored-by: dimapihtar <dimapihtar@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
1 parent b56af44 commit e295dbc

File tree

19 files changed: +403 −63 lines


docs/source/checkpoints/dist_ckpt.rst

Lines changed: 59 additions & 2 deletions
@@ -330,16 +330,21 @@ dist_checkpointing.load_common_state_dict
 The ``dist_checkpointing.load_common_state_dict`` function is an entry point that allows loading only the “common” part of the checkpoints.
 Most of the checkpoint config and metadata can be loaded with this method, which allows skipping data loading in order to make decisions about the checkpoint config, version, etc.

-dist_checkpointing.load_tensors_metadata
+dist_checkpointing.load_sharded_metadata
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-The ``dist_checkpointing.load_tensors_metadata`` function is an entry point that allows reading all ShardedTensors metadata from the checkpoint without loading any data.
+The ``dist_checkpointing.load_sharded_metadata`` function is an entry point that allows reading all ShardedTensors metadata from the checkpoint without loading any data.
 The result is a sharded state dict with trivial sharding (every tensor is sharded into one big shard).

 dist_checkpointing.load_plain_tensors
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 The ``dist_checkpointing.load_plain_tensors`` function is an entry point that allows reading sharded tensors stored in the checkpoint without any sharding (as plain tensors).
 This function is simply a composition of ``load_tensors_metadata`` and ``load``.

+dist_checkpointing.load_content_metadata
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+The ``dist_checkpointing.load_content_metadata`` function is an entry point that allows reading the content versioning metadata saved during ``save``.
+See `Checkpoint versioning`_ for more details.
+
 Save and Load Strategies
 ------------------------
 There are multiple ways to save a sharded state dict into a serialized checkpoint. They can be provided by the user as saving and loading strategies (e.g. ``TorchDistLoadShardedStrategy`` and ``TorchDistSaveShardedStrategy`` as shown below).
@@ -530,6 +535,50 @@ and using the ``dist_checkpointing.save`` and ``dist_checkpointing.load`` entryp
 In Megatron Core, the sharded state dictionary preparation is already implemented in a ``sharded_state_dict`` method which creates the sharded state dicts in a composable way.
 For other applications (e.g. with simpler types of supported parallelisms) it might be possible to apply a straightforward conversion from a regular model state dict into a sharded state dict.

+Checkpoint versioning
+^^^^^^^^^^^^^^^^^^^^^
+Megatron-Core v0.14 exposes a ``content_metadata`` argument for the ``save`` routine which allows storing metadata describing the checkpoint content (and a corresponding ``load_content_metadata`` function for loading it).
+In particular, this is the intended place to store application-specific versioning information - ``dist_checkpointing`` doesn't interpret the metadata at any point.
+The idea behind this feature is to provide a way to access content-identifying metadata without reading the whole checkpoint.
+Since loading a distributed checkpoint requires providing valid ShardedTensors to the ``load`` routine, in some cases it can be impossible
+to load the tensors from the checkpoint without using the content version to prepare the correct sharded state dict in advance.
+
+In the Megatron-LM and NeMo frameworks, the whole content metadata is passed to the ``sharded_state_dict`` model and optimizer methods
+and therefore affects only the logic behind sharded state dict creation.
+The recommended versioning practice for those frameworks is to use content metadata only for ``sharded_state_dict`` behavior control,
+i.e. to avoid storing metadata which affects framework logic in any other way.
+The content metadata should be minimalistic (to avoid bloated metadata with multiple possible configurations),
+ideally flat (or with a single nesting level) and with semantically meaningful flag names (e.g. ``distrib_optim_sharding_type`` or ``non_homogeneous_layers``).
+In particular, a simple integer (or SemVer) versioning flag (e.g. ``metadata['version'] = 3.4``) is discouraged,
+because the metadata serves all models and optimizers, and it's practically impossible to enforce a linearly increasing version across this whole space.
+
+In NeMo and Megatron-LM the versioning logic (calling the ``sharded_state_dict`` method with appropriate metadata) is already implemented.
+In order to introduce a new checkpoint version, two steps are required:
+
+1. Add a new flag to the metadata which is passed to ``sharded_state_dict`` methods by the framework (e.g. ``metadata['model_X_layout_Y'] = True``).
+   E.g. in NeMo the metadata is determined in the ``MegatronStrategy.sharded_state_dict_metadata`` property.
+
+2. Handle the new flag in the appropriate ``sharded_state_dict`` method (in Megatron-Core, the framework, or user code).
+   **Make sure to keep the old logic in case the new flag is absent. This ensures that both new and old checkpoints can be loaded correctly.**
+   This logic must be kept until the old checkpoint version is deprecated. The same applies to metadata flag removal. For example:
+
+.. code-block:: python
+
+    def sharded_state_dict(..., metadata: Optional[dict] = None):
+        if (metadata or {}).get('model_X_layout_Y', False):
+            # new behavior
+        else:
+            # old behavior
+        if (metadata or {}).get('already_removed_flag', False):
+            # old behavior (!)
+        else:
+            # new behavior
+
+Note: currently the content metadata is part of the "common" checkpoint state (and in consequence resides in the ``common.pt`` file), but this is an implementation
+detail and could change in the future. It is therefore recommended to save and load the content metadata with the API described at the beginning of this section.
+
+Note: currently in NeMo and Megatron-LM the versioning content is stored only in global checkpoints. For local checkpoints,
+it is assumed that the save and load content versions are the same, and thus ``sharded_state_dict`` uses runtime metadata in both cases.

 FAQs
 -----------------------
@@ -574,6 +623,14 @@ FAQs
 
 To accelerate checkpoint saving, it is recommended to set ``dist_ckpt_assume_constant_structure=True``.
 
+**9. Q: I get an error about an "invalid access pattern". What does it mean?**
+
+A: The logs print the access pattern tensor count. Its shape corresponds to the ShardedTensor sharding grid
+(e.g. a 3-dimensional parameter sharded by TP along the 1st axis would have an access pattern tensor of shape ``(1, TP size, 1)``).
+The tensor values correspond to the number of ShardedTensors whose main ``replica_id`` corresponds to that shard.
+A correct ``sharded_state_dict`` definition results in an access pattern with 1s in each cell. An invalid access pattern usually
+means an incorrect ShardedTensor sharding defined in the ``sharded_state_dict`` model method.
+

 Glossary
 -----------------------
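The flag-fallback pattern the new docs prescribe can be exercised as a plain-Python sketch. This is not Megatron-Core code: the function body, the returned labels, and the simplified signature are illustrative stand-ins for a real ``sharded_state_dict`` method.

```python
from typing import Optional


def sharded_state_dict(metadata: Optional[dict] = None) -> str:
    """Minimal stand-in for a versioned ``sharded_state_dict`` method.

    Returns a label naming which checkpoint layout would be produced,
    so the absent-flag fallback can be checked directly.
    """
    # New flag: absent (old checkpoint, or no metadata at all) means
    # the old layout must still be produced until it is deprecated.
    if (metadata or {}).get('model_X_layout_Y', False):
        return 'layout_Y'  # new behavior
    return 'layout_X'      # old behavior, kept for old checkpoints


# Old checkpoints carry no flag -> old layout.
print(sharded_state_dict())                            # layout_X
print(sharded_state_dict({}))                          # layout_X
# New checkpoints set the flag explicitly -> new layout.
print(sharded_state_dict({'model_X_layout_Y': True}))  # layout_Y
```

Note how ``(metadata or {})`` makes ``metadata=None`` and ``metadata={}`` behave identically, which is exactly what loading a pre-versioning checkpoint requires.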

nemo/collections/diffusion/models/flux/model.py

Lines changed: 4 additions & 1 deletion
@@ -370,7 +370,10 @@ def load_from_pretrained(
         if load_dist_ckpt:
             from megatron.core import dist_checkpointing
 
-            sharded_state_dict = dict(state_dict=self.sharded_state_dict(prefix="module."))
+            sharded_sd_metadata = dist_checkpointing.load_content_metadata(ckpt_path)
+            sharded_state_dict = dict(
+                state_dict=self.sharded_state_dict(prefix="module.", metadata=sharded_sd_metadata)
+            )
             loaded_state_dict = dist_checkpointing.load(
                 sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path
             )

nemo/collections/llm/inference/base.py

Lines changed: 2 additions & 1 deletion
@@ -180,7 +180,8 @@ def _setup_trainer_and_restore_model(
     peft: Optional[PEFT] = model.model_transform
     if isinstance(peft, PEFT):
         model = peft(model)
-    sharded_state_dict = MegatronModule.sharded_state_dict(model)
+    sharded_sd_metadata = trainer.strategy.unwrapped_checkpoint_io.load_content_metadata(path)
+    sharded_state_dict = MegatronModule.sharded_state_dict(model, metadata=sharded_sd_metadata)
     adapter_sharded_state_dict = {k: v for k, v in sharded_state_dict.items() if ".adapter." in k}
     adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
         ckpt_to_weights_subdir(path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict

nemo/collections/llm/modelopt/distill/utils.py

Lines changed: 2 additions & 1 deletion
@@ -145,7 +145,8 @@ def teacher_provider(
     # TODO(aanoosheh): Replace spec with modelopt one
     model = config.configure_model(tokenizer)
 
-    sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")}
+    sharded_sd_metadata = trainer.strategy.unwrapped_checkpoint_io.load_content_metadata(ckpt_path)
+    sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.", metadata=sharded_sd_metadata)}
     strict = trainer.strategy.ckpt_load_strictness
     checkpoint = trainer.strategy.checkpoint_io.load_checkpoint(ckpt_path, sharded_state_dict, strict=strict)
     state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}

nemo/collections/llm/modelopt/prune/pruner.py

Lines changed: 5 additions & 1 deletion
@@ -130,7 +130,11 @@ def save_pruned_model(trainer: nl.Trainer, save_path: str) -> None:
     # TODO: trainer.save_checkpoint(save_path) doesnt seem to save metadata.json or .metadata files!
     weight_path = ckpt_to_weights_subdir(save_path, is_saving=True)
     weight_path.mkdir(parents=True, exist_ok=True)
-    dist_checkpointing.save(trainer.strategy.megatron_parallel.sharded_state_dict(), weight_path)
+    dist_checkpointing.save(
+        trainer.strategy.megatron_parallel.sharded_state_dict(),
+        weight_path,
+        content_metadata=trainer.strategy.sharded_state_dict_metadata,
+    )
 
     if is_global_rank_zero():
         TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(save_path), yaml_attrs=["model"])

nemo/collections/llm/peft/api.py

Lines changed: 11 additions & 3 deletions
@@ -164,19 +164,27 @@ def _setup_trainer_and_restore_model_and_adapter(
     model.trainer = trainer
 
     lora(model)
+    weights_dir = ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False)
+    sharded_sd_metadata = trainer.strategy.unwrapped_checkpoint_io.load_content_metadata(weights_dir)
     adapter_sharded_state_dict = {
-        k: v for k, v in trainer.strategy.megatron_parallel.sharded_state_dict().items() if ".adapter." in k
+        k: v
+        for k, v in trainer.strategy.megatron_parallel.sharded_state_dict(metadata=sharded_sd_metadata).items()
+        if ".adapter." in k
     }
     adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
-        ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict
+        weights_dir, sharded_state_dict=adapter_sharded_state_dict
     )
     trainer.strategy.load_model_state_dict(adapter_state, strict=False)
 
 
 def _save_merged_weight(output_path: str, merged_weights: dict, model: pl.LightningModule, trainer: Trainer):
     weight_path = ckpt_to_weights_subdir(output_path, is_saving=True)
     Path(weight_path).mkdir(parents=True, exist_ok=True)
-    dist_checkpointing.save(merged_weights, str(ckpt_to_weights_subdir(output_path, is_saving=True)))
+    dist_checkpointing.save(
+        merged_weights,
+        str(ckpt_to_weights_subdir(output_path, is_saving=True)),
+        content_metadata=trainer.strategy.sharded_state_dict_metadata,
+    )
     if hasattr(model.tokenizer, "save_pretrained"):
         model.tokenizer.save_pretrained("/tmp/nemo_tokenizer")
     model.tokenizer = AutoTokenizer("/tmp/nemo_tokenizer")
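The contract these call sites rely on - content metadata persisted alongside the weights and readable *before* any tensor data is touched - can be mimicked with a small self-contained toy. ``toy_save`` and ``toy_load_content_metadata`` are hypothetical stand-ins for the real ``dist_checkpointing.save(..., content_metadata=...)`` / ``load_content_metadata`` pair, not the actual API.

```python
import json
import tempfile
from pathlib import Path


def toy_save(state: dict, ckpt_dir: Path, content_metadata: dict) -> None:
    """Persist weights plus a small metadata file next to them."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    (ckpt_dir / "content_metadata.json").write_text(json.dumps(content_metadata))
    (ckpt_dir / "weights.json").write_text(json.dumps(state))


def toy_load_content_metadata(ckpt_dir: Path) -> dict:
    """Read only the metadata; the (potentially huge) weights stay untouched."""
    return json.loads((ckpt_dir / "content_metadata.json").read_text())


with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "weights"
    toy_save(
        {"w": [1, 2]},
        ckpt,
        content_metadata={"distrib_optim_sharding_type": "fully_sharded_model_space"},
    )
    # On load, the metadata is fetched first and drives how the
    # sharded state dict is built before any tensors are requested.
    meta = toy_load_content_metadata(ckpt)
    print(meta["distrib_optim_sharding_type"])  # fully_sharded_model_space
```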

nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py

Lines changed: 0 additions & 1 deletion
@@ -2048,7 +2048,6 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
         self.state_dict().
         The sharded tensor mapping is defined in the GPTModel class from mcore.
         """
-
         if self.mcore_gpt:
             module_prefix = f'{prefix}model.'
             sharded_state_dict = {}

nemo/collections/speechlm/models/speech_to_text_llm_model.py

Lines changed: 4 additions & 2 deletions
@@ -218,11 +218,13 @@ def _maybe_load_pretrained_llm(self, model: MCoreGPTModel, strict: bool = False)
             llm_model_cls(self.language_model_config), f"{self.language_model_hub}{ckpt_path}", on_import_ckpt=False
         )
 
-        sharded_state_dict = dict(state_dict=model.sharded_state_dict(prefix="module."))
+        load_path = ckpt_to_weights_subdir(ckpt_path, is_saving=False)
+        sharded_sd_metadata = dist_checkpointing.load_content_metadata(load_path)
+        sharded_state_dict = dict(state_dict=model.sharded_state_dict(prefix="module.", metadata=sharded_sd_metadata))
 
         loaded_state_dict = dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict,
-            checkpoint_dir=ckpt_to_weights_subdir(ckpt_path, is_saving=False),
+            checkpoint_dir=load_path,
             validate_access_integrity=False,
             **({"strict": "log_all"} if not strict else {}),
         )

nemo/collections/vlm/neva/model/base.py

Lines changed: 4 additions & 2 deletions
@@ -83,10 +83,12 @@ def restore_model_weights(model, checkpoint_path, strict=False):
         strict: Whether to restore weights even if they are not the same.
     """
     if checkpoint_path is not None:
-        sharded_state_dict = dict(state_dict=model.sharded_state_dict(prefix="module."))
+        weights_dir = ckpt_to_weights_subdir(checkpoint_path, is_saving=False)
+        sharded_sd_metadata = dist_checkpointing.load_content_metadata(weights_dir)
+        sharded_state_dict = dict(state_dict=model.sharded_state_dict(prefix="module.", metadata=sharded_sd_metadata))
         loaded_state_dict = dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict,
-            checkpoint_dir=ckpt_to_weights_subdir(checkpoint_path, is_saving=False),
+            checkpoint_dir=weights_dir,
             validate_access_integrity=False,
             **({"strict": "log_all"} if not strict else {}),
         )

nemo/core/optim/mcore_optim.py

Lines changed: 16 additions & 5 deletions
@@ -14,6 +14,7 @@
 
 import torch
 
+from nemo.utils import logging
 from nemo.utils.nvtx import nvtx_range_pop, nvtx_range_push
 
 
@@ -94,7 +95,12 @@ def load_state_dict(self, state_dict):
         self.mcore_optimizer.load_state_dict(state_dict)
 
     def sharded_state_dict(
-        self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
+        self,
+        model_sharded_state_dict,
+        optimizer_state_dict=None,
+        is_loading=False,
+        dist_ckpt_parallel_save=None,
+        **kwargs,
     ):
         """
         Returns the sharded state dictionary for distributed checkpointing.
@@ -109,10 +115,15 @@ def sharded_state_dict(
         Returns:
             dict: The sharded optimizer state dictionary.
         """
-        sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
-        return self.mcore_optimizer.sharded_state_dict(
-            model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
-        )
+        if dist_ckpt_parallel_save is not None:
+            logging.warning(
+                "dist_ckpt_parallel_save is deprecated, please use `metadata['distrib_optim_sharding_type']`"
+                " to specify DistributedOptimizer format details instead."
+            )
+            kwargs['sharding_type'] = (
+                'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
+            )
+        return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading=is_loading, **kwargs)
 
     def step(self, closure=None):
         """
