|
58 | 58 | from megatron.core.dist_checkpointing.validation import StrictHandling |
59 | 59 | from megatron.core.distributed import DistributedDataParallelConfig |
60 | 60 | from megatron.core.optimizer import OptimizerConfig |
61 | | - from megatron.core.utils import get_torch_version, is_torch_min_version |
62 | 61 |
|
63 | 62 | HAVE_MEGATRON_CORE = True |
64 | 63 | except (ImportError, ModuleNotFoundError): |
@@ -216,11 +215,6 @@ class MegatronStrategy(DDPStrategy, io.IOMixin): |
216 | 215 | If not None, overwrites the `strict` flag passed to `load_checkpoint`. |
217 | 216 | Defaults to None. For a list of supported values, refer to the Megatron Core documentation: |
218 | 217 | https://github.com/NVIDIA/Megatron-LM/blob/d4e72c0d33edc0c53aeb624f617eb77cebce6ae9/megatron/core/dist_checkpointing/validation.py#L46 |
219 | | - ckpt_save_pre_mcore_014 (bool, optional): if True, brings back sharded state dict definition from |
220 | | - before Megatron-Core v0.14 versions for checkpoint saving. It doesn't affect loading as the |
221 | | - loading format is determined based on metadata stored in the checkpoint. This flag is provided |
222 | | - temporarily as a fallback to previous behavior in case of unexpected issues with the new formats. |
223 | | - Defaults to False. |
224 | 218 | ckpt_optim_fully_reshardable (bool, optional): switches to a fully reshardable (TP/PP/DP/EP) |
225 | 219 | optimizer format. Defaults to False, in which case a DP-only reshardable format is used. |
226 | 220 | distrib_optim_fully_reshardable_mem_efficient (bool, optional): minimizes CUDA and host memory |
@@ -301,7 +295,6 @@ def __init__( |
301 | 295 | ckpt_parallel_save_optim: Optional[bool] = None, |
302 | 296 | ckpt_load_directly_on_device: bool = True, |
303 | 297 | ckpt_load_strictness: Optional['StrictHandling'] = None, |
304 | | - ckpt_save_pre_mcore_014: bool = False, |
305 | 298 | ckpt_optim_fully_reshardable: bool = False, |
306 | 299 | distrib_optim_fully_reshardable_mem_efficient: bool = False, |
307 | 300 | setup_optimizers: bool = True, |
@@ -352,7 +345,6 @@ def __init__( |
352 | 345 | self.ckpt_save_optimizer = ckpt_save_optimizer |
353 | 346 | self.ckpt_load_main_params = ckpt_load_main_params |
354 | 347 | self.ckpt_load_strictness = ckpt_load_strictness |
355 | | - self.ckpt_save_pre_mcore_014 = ckpt_save_pre_mcore_014 |
356 | 348 | self.ckpt_optim_fully_reshardable = ckpt_optim_fully_reshardable |
357 | 349 | self.distrib_optim_fully_reshardable_mem_efficient = distrib_optim_fully_reshardable_mem_efficient |
358 | 350 | self.use_te_rng_tracker = use_te_rng_tracker |
@@ -442,11 +434,10 @@ def __init__( |
442 | 434 | if self.ckpt_load_optimizer and self.ckpt_load_main_params: |
443 | 435 | raise ValueError("ckpt_load_optimizer and ckpt_load_main_params cannot be both set to True.") |
444 | 436 |
|
445 | | - if self.parallel_save_optim is not None and not self.ckpt_save_pre_mcore_014: |
| 437 | + if self.parallel_save_optim is not None: |
446 | 438 | logging.warning( |
447 | 439 | "`ckpt_parallel_save_optim` argument is replaced with" |
448 | 440 | " `ckpt_optim_fully_reshardable` and does not have any effect" |
449 | | - " (unless used together with `ckpt_save_pre_mcore_014=True`)" |
450 | 441 | ) |
451 | 442 |
|
452 | 443 | if isinstance(self.ddp_config, DistributedDataParallelConfig): |
@@ -1228,28 +1219,14 @@ def sharded_state_dict_metadata(self): |
1228 | 1219 | if use_distributed_optimizer and use_megatron_fsdp: |
1229 | 1220 | metadata["distrib_optim_sharding_type"] = "fsdp_dtensor" |
1230 | 1221 |
|
1231 | | - force_pre_mcore_014 = not is_torch_min_version("2.6a0") |
1232 | | - if force_pre_mcore_014: |
1233 | | - logging.warning( |
1234 | | - f"PyTorch version {get_torch_version()} below 2.6 detected." |
1235 | | - f" Forcing ckpt_save_pre_mcore_014 behavior." |
1236 | | - ) |
1237 | | - |
1238 | | - if self.ckpt_save_pre_mcore_014 or force_pre_mcore_014: |
1239 | | - if use_distributed_optimizer and not use_megatron_fsdp: |
1240 | | - if self.parallel_save_optim: |
1241 | | - metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space" |
1242 | | - else: |
1243 | | - metadata["distrib_optim_sharding_type"] = "dp_zero_gather_scatter" |
1244 | | - else: |
1245 | | - if use_distributed_optimizer and not use_megatron_fsdp: |
1246 | | - if self.ckpt_optim_fully_reshardable: |
1247 | | - metadata['distrib_optim_sharding_type'] = 'fully_reshardable' |
1248 | | - metadata['distrib_optim_fully_reshardable_mem_efficient'] = ( |
1249 | | - self.distrib_optim_fully_reshardable_mem_efficient |
1250 | | - ) |
1251 | | - else: |
1252 | | - metadata['distrib_optim_sharding_type'] = 'dp_reshardable' |
| 1222 | + if use_distributed_optimizer and not use_megatron_fsdp: |
| 1223 | + if self.ckpt_optim_fully_reshardable: |
| 1224 | + metadata['distrib_optim_sharding_type'] = 'fully_reshardable' |
| 1225 | + metadata['distrib_optim_fully_reshardable_mem_efficient'] = ( |
| 1226 | + self.distrib_optim_fully_reshardable_mem_efficient |
| 1227 | + ) |
| 1228 | + else: |
| 1229 | + metadata['distrib_optim_sharding_type'] = 'dp_reshardable' |
1253 | 1230 | return metadata |
1254 | 1231 |
|
1255 | 1232 | def selective_restore(self) -> None: |
|
0 commit comments