
Commit 77026c3

authored by dimapihtar
remove ckpt_save_pre_mcore_014 support (#15146)
* remove ckpt_save_pre_mcore_014 param

  Signed-off-by: dimapihtar <[email protected]>

* remove imports

  Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: dimapihtar <[email protected]>
1 parent 5651c2f commit 77026c3


2 files changed: +14 −52 lines changed


nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 9 additions & 32 deletions
@@ -58,7 +58,6 @@
     from megatron.core.dist_checkpointing.validation import StrictHandling
     from megatron.core.distributed import DistributedDataParallelConfig
     from megatron.core.optimizer import OptimizerConfig
-    from megatron.core.utils import get_torch_version, is_torch_min_version
 
     HAVE_MEGATRON_CORE = True
 except (ImportError, ModuleNotFoundError):
@@ -216,11 +215,6 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             If not None, overwrites the `strict` flag passed to `load_checkpoint`.
             Defaults to None. For a list of supported values, refer to the Megatron Core documentation:
             https://github.com/NVIDIA/Megatron-LM/blob/d4e72c0d33edc0c53aeb624f617eb77cebce6ae9/megatron/core/dist_checkpointing/validation.py#L46
-        ckpt_save_pre_mcore_014 (bool, optional): if True, brings back sharded state dict definition from
-            before Megatron-Core v0.14 versions for checkpoint saving. It doesn't affect loading as the
-            loading format is determined based on metadata stored in the checkpoint. This flag is provided
-            temporarily as a fallback to previous behavior in case of unexpected issues with the new formats.
-            Defaults to False.
         ckpt_optim_fully_reshardable (bool, optional): switches to a fully reshardable (TP/PP/DP/EP)
             optimizer format. Defaults to False, in which case a DP-only reshardable format is used.
         distrib_optim_fully_reshardable_mem_efficient (bool, optional): minimizes CUDA and host memory
@@ -301,7 +295,6 @@ def __init__(
         ckpt_parallel_save_optim: Optional[bool] = None,
         ckpt_load_directly_on_device: bool = True,
         ckpt_load_strictness: Optional['StrictHandling'] = None,
-        ckpt_save_pre_mcore_014: bool = False,
         ckpt_optim_fully_reshardable: bool = False,
         distrib_optim_fully_reshardable_mem_efficient: bool = False,
         setup_optimizers: bool = True,
@@ -352,7 +345,6 @@ def __init__(
         self.ckpt_save_optimizer = ckpt_save_optimizer
         self.ckpt_load_main_params = ckpt_load_main_params
         self.ckpt_load_strictness = ckpt_load_strictness
-        self.ckpt_save_pre_mcore_014 = ckpt_save_pre_mcore_014
         self.ckpt_optim_fully_reshardable = ckpt_optim_fully_reshardable
         self.distrib_optim_fully_reshardable_mem_efficient = distrib_optim_fully_reshardable_mem_efficient
         self.use_te_rng_tracker = use_te_rng_tracker
@@ -442,11 +434,10 @@ def __init__(
         if self.ckpt_load_optimizer and self.ckpt_load_main_params:
             raise ValueError("ckpt_load_optimizer and ckpt_load_main_params cannot be both set to True.")
 
-        if self.parallel_save_optim is not None and not self.ckpt_save_pre_mcore_014:
+        if self.parallel_save_optim is not None:
             logging.warning(
                 "`ckpt_parallel_save_optim` argument is replaced with"
                 " `ckpt_optim_fully_reshardable` and does not have any effect"
-                " (unless used together with `ckpt_save_pre_mcore_014=True`)"
             )
 
         if isinstance(self.ddp_config, DistributedDataParallelConfig):
@@ -1228,28 +1219,14 @@ def sharded_state_dict_metadata(self):
         if use_distributed_optimizer and use_megatron_fsdp:
             metadata["distrib_optim_sharding_type"] = "fsdp_dtensor"
 
-        force_pre_mcore_014 = not is_torch_min_version("2.6a0")
-        if force_pre_mcore_014:
-            logging.warning(
-                f"PyTorch version {get_torch_version()} below 2.6 detected."
-                f" Forcing ckpt_save_pre_mcore_014 behavior."
-            )
-
-        if self.ckpt_save_pre_mcore_014 or force_pre_mcore_014:
-            if use_distributed_optimizer and not use_megatron_fsdp:
-                if self.parallel_save_optim:
-                    metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space"
-                else:
-                    metadata["distrib_optim_sharding_type"] = "dp_zero_gather_scatter"
-        else:
-            if use_distributed_optimizer and not use_megatron_fsdp:
-                if self.ckpt_optim_fully_reshardable:
-                    metadata['distrib_optim_sharding_type'] = 'fully_reshardable'
-                    metadata['distrib_optim_fully_reshardable_mem_efficient'] = (
-                        self.distrib_optim_fully_reshardable_mem_efficient
-                    )
-                else:
-                    metadata['distrib_optim_sharding_type'] = 'dp_reshardable'
+        if use_distributed_optimizer and not use_megatron_fsdp:
+            if self.ckpt_optim_fully_reshardable:
+                metadata['distrib_optim_sharding_type'] = 'fully_reshardable'
+                metadata['distrib_optim_fully_reshardable_mem_efficient'] = (
+                    self.distrib_optim_fully_reshardable_mem_efficient
+                )
+            else:
+                metadata['distrib_optim_sharding_type'] = 'dp_reshardable'
         return metadata
 
     def selective_restore(self) -> None:
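
For reference, after this change the `sharded_state_dict_metadata` branching for the distributed optimizer reduces to the logic sketched below. This is an illustrative standalone sketch of the diff above, with `use_distributed_optimizer` and `use_megatron_fsdp` passed in explicitly rather than derived from the strategy's DDP config as the real property does:

def resolve_distrib_optim_metadata(
    use_distributed_optimizer: bool,
    use_megatron_fsdp: bool,
    ckpt_optim_fully_reshardable: bool,
    mem_efficient: bool,
) -> dict:
    # Mirrors the single branch kept by this commit; the pre-Megatron-Core-0.14
    # 'fully_sharded_model_space' / 'dp_zero_gather_scatter' paths are gone.
    metadata = {}
    if use_distributed_optimizer and use_megatron_fsdp:
        metadata["distrib_optim_sharding_type"] = "fsdp_dtensor"
    if use_distributed_optimizer and not use_megatron_fsdp:
        if ckpt_optim_fully_reshardable:
            metadata["distrib_optim_sharding_type"] = "fully_reshardable"
            metadata["distrib_optim_fully_reshardable_mem_efficient"] = mem_efficient
        else:
            metadata["distrib_optim_sharding_type"] = "dp_reshardable"
    return metadata

For example, resolve_distrib_optim_metadata(True, False, True, False) yields the 'fully_reshardable' sharding type, which matches the updated test expectations in the second file.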

tests/lightning/pytorch/strategies/test_megatron_strategy.py

Lines changed: 5 additions & 20 deletions
@@ -22,25 +22,18 @@
 
 
 def get_metadata(
-    ckpt_save_pre_mcore_014: bool = None,
     ckpt_parallel_save_optim: bool = None,
     ckpt_optim_fully_reshardable: bool = None,
 ) -> dict:
     metadata = {
         'singleton_local_shards': False,
         'chained_optim_avoid_prefix': True,
     }
-    if ckpt_save_pre_mcore_014:
-        if ckpt_parallel_save_optim:
-            metadata['distrib_optim_sharding_type'] = 'fully_sharded_model_space'
-        else:
-            metadata['distrib_optim_sharding_type'] = 'dp_zero_gather_scatter'
+    if ckpt_optim_fully_reshardable:
+        metadata['distrib_optim_sharding_type'] = 'fully_reshardable'
+        metadata['distrib_optim_fully_reshardable_mem_efficient'] = False
     else:
-        if ckpt_optim_fully_reshardable:
-            metadata['distrib_optim_sharding_type'] = 'fully_reshardable'
-            metadata['distrib_optim_fully_reshardable_mem_efficient'] = False
-        else:
-            metadata['distrib_optim_sharding_type'] = 'dp_reshardable'
+        metadata['distrib_optim_sharding_type'] = 'dp_reshardable'
 
     return metadata
 

@@ -95,18 +88,10 @@ def test_ckpt_load_main_params_without_state_dict(self):
         strategy.optimizers[0].reload_model_params.assert_called_once_with(checkpoint)
 
     def test_sharded_state_dict_metadata(self):
-        strategy = MegatronStrategy(ckpt_save_pre_mcore_014=False, ckpt_parallel_save_optim=True)
+        strategy = MegatronStrategy(ckpt_parallel_save_optim=True)
 
         ddp = DistributedDataParallelConfig(use_distributed_optimizer=True)
 
-        strategy = MegatronStrategy(ckpt_save_pre_mcore_014=True, ckpt_parallel_save_optim=True, ddp=ddp)
-        metadata = strategy.sharded_state_dict_metadata
-        assert metadata == get_metadata(ckpt_save_pre_mcore_014=True, ckpt_parallel_save_optim=True)
-
-        strategy = MegatronStrategy(ckpt_save_pre_mcore_014=True, ddp=ddp)
-        metadata = strategy.sharded_state_dict_metadata
-        assert metadata == get_metadata(ckpt_save_pre_mcore_014=True)
-
         strategy = MegatronStrategy(ckpt_optim_fully_reshardable=True, ddp=ddp)
         metadata = strategy.sharded_state_dict_metadata
         assert metadata == get_metadata(ckpt_optim_fully_reshardable=True)
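
Taken together, callers that previously passed `ckpt_save_pre_mcore_014` should drop the argument and rely on the remaining checkpoint flags. A hedged usage sketch based on the signatures visible in this diff (not a complete training setup):

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

# Before this commit (argument no longer accepted):
#   MegatronStrategy(ckpt_save_pre_mcore_014=True, ckpt_parallel_save_optim=True)

# After this commit, optimizer checkpoint reshardability is controlled by:
strategy = MegatronStrategy(
    ckpt_optim_fully_reshardable=True,                    # fully reshardable (TP/PP/DP/EP) optimizer format
    distrib_optim_fully_reshardable_mem_efficient=False,  # optional memory-efficient variant of that format
)

Passing `ckpt_parallel_save_optim` still constructs the strategy but now always emits the warning shown in the first file, since the argument has no effect without the removed flag.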
