The ``dist_checkpointing.load_common_state_dict`` function is an entry point that allows loading only the “common” part of the checkpoints.
Most of the checkpoint config and metadata can be loaded with this method, which allows skipping data loading in order to make decisions regarding checkpoint config, version, etc.
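To illustrate the idea (this is a conceptual sketch, not the real implementation), reading only the common part amounts to deserializing one small file while ignoring all tensor data. Here ``pickle`` and a hand-rolled ``load_common_state_dict`` stand in for the actual torch-based serialization:

```python
import os
import pickle
import tempfile

# Hypothetical sketch: the "common" part of a checkpoint holds small,
# non-sharded objects (config, iteration count, etc.), so it can be read
# without touching any tensor data. Real checkpoints use torch serialization;
# pickle is only a stand-in here.
ckpt_dir = tempfile.mkdtemp()
with open(os.path.join(ckpt_dir, 'common.pt'), 'wb') as f:
    pickle.dump({'checkpoint_version': 3.0, 'iteration': 1000}, f)

def load_common_state_dict(checkpoint_dir):
    # stand-in for the dist_checkpointing.load_common_state_dict entry point
    with open(os.path.join(checkpoint_dir, 'common.pt'), 'rb') as f:
        return pickle.load(f)

# The config/metadata is available without loading any sharded tensors.
common = load_common_state_dict(ckpt_dir)
```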
dist_checkpointing.load_sharded_metadata
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``dist_checkpointing.load_sharded_metadata`` function is an entry point that allows reading all ShardedTensors metadata from the checkpoint without loading any data.
The result is a sharded state dict with trivial sharding (every tensor is sharded into one big shard).

dist_checkpointing.load_plain_tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``dist_checkpointing.load_plain_tensors`` function is an entry point that allows reading sharded tensors stored in the checkpoint without any sharding (as plain tensors).
This function is simply a composition of ``load_sharded_metadata`` and ``load``.

dist_checkpointing.load_content_metadata
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The ``dist_checkpointing.load_content_metadata`` function is an entry point that allows reading the content versioning metadata saved during ``save``.
See `Checkpoint versioning`_ for more details.

Save and Load Strategies
------------------------
There are multiple ways to save a sharded state dict into a serialized checkpoint. They can be provided by the user as saving and loading strategies (e.g. ``TorchDistLoadShardedStrategy`` and ``TorchDistSaveShardedStrategy`` as shown below).
In Megatron Core, the sharded state dictionary preparation is already implemented in a ``sharded_state_dict`` method which creates the sharded state dicts in a composable way.
For other applications (e.g. with simpler types of supported parallelisms) it might be possible to apply a straightforward conversion from a regular model state dict into a sharded state dict.

Checkpoint versioning
^^^^^^^^^^^^^^^^^^^^^
Megatron-Core v0.14 exposes a ``content_metadata`` flag for the ``save`` routine which allows storing metadata describing the checkpoint content (and a corresponding ``load_content_metadata`` function for loading it).
In particular, this is the intended place to store application-specific versioning information - ``dist_checkpointing`` doesn't interpret the metadata at any point.
The idea behind this feature is to provide a way to access content-identifying metadata without reading the whole checkpoint.
Since loading a distributed checkpoint requires providing valid ShardedTensors to the ``load`` routine, in some cases it can be impossible
to load the tensors from the checkpoint without using the content version to prepare the correct sharded state dict in advance.

In Megatron-LM and NeMo frameworks, the whole content metadata is passed to ``sharded_state_dict`` model and optimizer methods
and therefore affects only the logic behind ``sharded_state_dict`` creation.
The recommended versioning practice for those frameworks is to use content metadata only for ``sharded_state_dict`` behavior control,
i.e. avoid storing metadata which affects framework logic in any other way.
The content metadata should be minimalistic (to avoid bloated metadata with multiple possible configurations),
ideally flat (or with a single nesting level) and with semantically meaningful flag names (e.g. ``distrib_optim_sharding_type`` or ``non_homogeneous_layers``).
In particular, a simple integer (or SemVer) versioning flag (e.g. ``metadata['version'] = 3.4``) is discouraged,
because the metadata serves all models and optimizers and it's practically impossible to enforce linearly increasing versioning for this whole space.
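Following these recommendations, a content metadata dict could look like the sketch below. The flag names echo the examples above, but the values are made up for illustration; ``dist_checkpointing`` imposes no schema:

```python
# Hypothetical content metadata: minimalistic, flat, with semantically
# meaningful flag names (the values here are illustrative, not a fixed schema).
content_metadata = {
    'distrib_optim_sharding_type': 'fully_sharded_model_space',
    'non_homogeneous_layers': True,
}

# Discouraged: a single opaque version number covering every model/optimizer.
discouraged_metadata = {'version': 3.4}
```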

In NeMo or Megatron-LM the versioning logic (calling the ``sharded_state_dict`` method with appropriate metadata) is already implemented.
In order to introduce a new checkpoint version, two steps are required:

1. Add a new flag to the metadata which is passed to ``sharded_state_dict`` methods by the framework (e.g. ``metadata['model_X_layout_Y'] = True``).
   E.g. in NeMo the metadata is determined in the ``MegatronStrategy.sharded_state_dict_metadata`` property.

2. Handle the new flag in the appropriate ``sharded_state_dict`` method (in Megatron-Core, framework, or user code).
   **Make sure to keep the old logic in case the new flag is absent. This will ensure both the new and old checkpoints can be loaded correctly**.
   This logic must be kept until the old checkpoint version is deprecated. Similarly with metadata flag removal. For example::

      if (metadata or {}).get('model_X_layout_Y', False):
          # new behavior
      else:
          # old behavior

      if (metadata or {}).get('already_removed_flag', False):
          # old behavior (!)
      else:
          # new behavior
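This pattern can be exercised end to end. The sketch below (with the same hypothetical flag name and a made-up ``select_layout`` helper) shows why the absent-flag path must keep the old behavior: pre-versioning checkpoints carry no metadata at all.

```python
# Illustrative helper (not Megatron-Core API): selects sharded_state_dict
# behavior based on content metadata, defaulting to the old path.
def select_layout(metadata):
    # Checkpoints saved before versioning was introduced carry no metadata,
    # hence the `(metadata or {})` guard against None.
    if (metadata or {}).get('model_X_layout_Y', False):
        return 'new_layout'  # new behavior
    return 'old_layout'      # old behavior

assert select_layout(None) == 'old_layout'    # pre-versioning checkpoint
assert select_layout({}) == 'old_layout'      # flag absent
assert select_layout({'model_X_layout_Y': True}) == 'new_layout'
```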

Note: Currently the content metadata is part of the "common" checkpoint state (and in consequence resides in the ``common.pt`` file) but this is an implementation
detail and could be changed in the future. Therefore it's recommended to save/load the content metadata with the API described at the beginning of this section.

Note: currently in NeMo and Megatron-LM versioning content is stored only in global checkpoints. For local checkpoints,
it is assumed that the save and load content versions are the same and thus ``sharded_state_dict`` uses runtime metadata in both cases.
FAQs
-----------------------
To accelerate checkpoint saving, it is recommended to set ``dist_ckpt_assume_constant_structure=True``.

**9. Q: I get an error about an "invalid access pattern". What does it mean?**

A: The logs print the access pattern tensor. Its shape corresponds to the ShardedTensor sharding grid
(e.g. 3-dimensional parameter sharded by TP along the 1st axis would have the access pattern tensor of shape ``(1, TP size, 1)``).
The tensor values correspond to the number of ShardedTensors with main ``replica_id`` corresponding to that shard.
A correct ``sharded_state_dict`` definition results in an access pattern with 1s in each cell. An invalid access pattern usually
means an incorrect ShardedTensor sharding defined in the ``sharded_state_dict`` model method.
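As an illustration of how such a pattern arises (this mimics the diagnostic, it is not the actual Megatron-Core code): for a parameter sharded only along TP, each of the ``TP size`` grid cells should be claimed by exactly one main-replica ShardedTensor.

```python
# Illustrative sketch (not the Megatron-Core implementation): count how many
# main-replica ShardedTensors claim each cell of a (1, TP, 1) sharding grid;
# only the TP axis varies, so the grid collapses to a list of length tp_size.
def access_pattern(shards, tp_size):
    counts = [0] * tp_size
    for tp_rank, replica_id in shards:
        if replica_id == 0:  # only the main replica is expected to own the data
            counts[tp_rank] += 1
    return counts

# Valid sharded_state_dict: every cell claimed exactly once -> all ones.
valid = [(tp, 0) for tp in range(4)]
assert access_pattern(valid, 4) == [1, 1, 1, 1]

# Invalid: two ranks declare the same shard as main replica -> a cell holds 2.
invalid = valid + [(1, 0)]
assert access_pattern(invalid, 4) == [1, 2, 1, 1]
```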