Commit fa007c8

cp add_fc_into_zcc to nlp (#11187)
1 parent: c8f1414 · commit: fa007c8

File tree

2 files changed: +749 -96 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 103 additions & 48 deletions
@@ -188,13 +188,17 @@
 
 try:
     from .utils.zero_cost_checkpoint import (
+        DistInfoCollectorValidator,
         NonZCCEMACallback,
         ZeroCostCheckpointCallback,
+        ZeroCostCheckpointCallbackFcBased,
         ZeroCostCheckpointManager,
-        get_fused_param_mappings,
+        ZeroCostCheckpointWorker,
+        ZeroCostCheckpointWorkerFcBased,
     )
 except (ImportError, ModuleNotFoundError):
-    ZeroCostCheckpointManager, NonZCCEMACallback, get_fused_param_mappings = None, None, None
+    ZeroCostCheckpointManager, NonZCCEMACallback = None, None
+
 from .utils.helper import (  # nested_truncate,
     broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
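Note on the hunk above: the try/except is an optional-dependency guard, so the trainer module still imports when the zero-cost-checkpoint utilities are unavailable, and later code can compare the names against None before using them. Below is a minimal, self-contained sketch of the same pattern; the package and helper names are illustrative only and are not part of this commit.

# Hedged sketch of the optional-import guard used above (names are hypothetical).
try:
    from some_optional_package import FeatureManager  # hypothetical optional dependency
except (ImportError, ModuleNotFoundError):
    FeatureManager = None  # downstream code checks for None before using the feature


def feature_available() -> bool:
    """Illustrative helper (not in the commit): report whether the optional import succeeded."""
    return FeatureManager is not None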
@@ -796,6 +800,90 @@ def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
         self._load_from_checkpoint(resume_from_checkpoint)
         return model
 
+    def _get_zcc_implementation_classes(self):
+        """Get appropriate ZCC implementation classes based on checkpoint format."""
+        if self.args.save_checkpoint_format == "flex_checkpoint":
+            return ZeroCostCheckpointCallbackFcBased, ZeroCostCheckpointWorkerFcBased
+        return ZeroCostCheckpointCallback, ZeroCostCheckpointWorker
+
+    def _create_zcc_manager_instance(self, unwrapped_model, zcc_worker_class):
+        """Create ZCC manager instance with appropriate configuration."""
+        if isinstance(self.model, PipelineLayer):
+            pipeline_hooks_capacity = (
+                unwrapped_model.forward_pipeline_parallel_hook_capacity
+                + unwrapped_model.backward_pipeline_parallel_hook_capacity
+            )
+        else:
+            pipeline_hooks_capacity = self.args.gradient_accumulation_steps
+
+        return ZeroCostCheckpointManager(
+            worker_num=self.args.zcc_workers_num,
+            pipeline_hooks_capacity=pipeline_hooks_capacity,
+            capacity_usage=self.args.zcc_pipeline_hooks_capacity_usage,
+            use_expert_parallel=self.args.use_expert_parallel,
+            ema_coef=self.args.zcc_save_ema_coef,
+            zcc_worker_class=zcc_worker_class,
+        )
+
+    def _register_pipeline_hooks(self, unwrapped_model):
+        """Register forward and backward pipeline hooks."""
+        # Register forward hooks
+        for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
+            unwrapped_model.register_forward_pipeline_parallel_hook(
+                location=i, hook=self.zcc_manager.zcc_pipeline_hook
+            )
+
+        # Register backward hooks
+        for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
+            unwrapped_model.register_backward_pipeline_parallel_hook(
+                location=i, hook=self.zcc_manager.zcc_pipeline_hook
+            )
+
+    def _setup_zcc_callback(self, zcc_callback_class):
+        """Setup ZCC callback with required dependencies."""
+        callback = zcc_callback_class(self.args, self.zcc_manager, self.runtime_timer, self.sharding_io)
+        self.add_callback(callback)
+
+    def _handle_checkpoint_resume(self, resume_from_checkpoint):
+        """Handle resumption from previous checkpoint if provided."""
+        if resume_from_checkpoint is None:
+            return
+
+        ema_state_path = self._get_ema_state_path(resume_from_checkpoint)
+
+        if not os.path.exists(ema_state_path):
+            logger.info(f"ZCC EMA state dict not found at: {ema_state_path}")
+            return
+
+        # Validate distributed strategy compatibility
+        should_load_ema = self._should_load_ema_state(resume_from_checkpoint, ema_state_path)
+
+        if should_load_ema:
+            logger.info(f"Loading ZCC EMA state from: {ema_state_path}")
+            self.zcc_manager.set_ema_state_dict(ema_state_path)
+
+    def _get_ema_state_path(self, checkpoint_path):
+        """Get the path to EMA state based on checkpoint format."""
+        if self.args.save_checkpoint_format == "flex_checkpoint":
+            return os.path.join(checkpoint_path, "ema_state", f"{dist.get_rank()}_0.distcp")
+        else:
+            optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+            return os.path.join(checkpoint_path, optimizer_name).replace("optimizer", "ema")
+
+    def _should_load_ema_state(self, checkpoint_path, ema_state_path):
+        """Determine if EMA state should be loaded based on configuration and compatibility."""
+        if self.args.zcc_save_ema_coef is None:
+            logger.info("EMA coefficient is None, skipping EMA state loading")
+            return False
+
+        success, err_msg = DistInfoCollectorValidator(self.args, self.hcg).check_same_strategy(checkpoint_path)
+
+        if not success:
+            logger.warning(f"Cannot load EMA state due to strategy mismatch: {err_msg}")
+            return False
+
+        return True
+
     def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         """
         Create zero cost checkpoint manager.
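To make the resume logic in _get_ema_state_path concrete, the following small, runnable sketch reproduces the two path layouts it chooses between. The checkpoint directory, rank, optimizer file name, and the non-flex format label are illustrative assumptions, not values taken from this commit; the real code builds the optimizer file name via _add_variant(PADDLE_OPTIMIZER_NAME, ...).

import os

def ema_state_path(checkpoint_dir, save_checkpoint_format, rank, optimizer_file):
    """Sketch of the path selection in _get_ema_state_path (simplified: no _add_variant)."""
    if save_checkpoint_format == "flex_checkpoint":
        # flex_checkpoint keeps per-rank distcp shards under an ema_state/ subdirectory
        return os.path.join(checkpoint_dir, "ema_state", f"{rank}_0.distcp")
    # other formats reuse the optimizer file name with "optimizer" replaced by "ema"
    return os.path.join(checkpoint_dir, optimizer_file).replace("optimizer", "ema")

print(ema_state_path("checkpoint-1000", "flex_checkpoint", 3, "optimizer.pdopt"))
# checkpoint-1000/ema_state/3_0.distcp
print(ema_state_path("checkpoint-1000", "other_format", 3, "optimizer.pdopt"))
# checkpoint-1000/ema.pdopt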
@@ -806,55 +894,22 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
             self.model, PretrainedModel
         ), "model should be a PretrainedModel when using zero cost checkpoint"
         logger.info("Create zero cost checkpoint manager...")
+
+        zcc_callback_class, zcc_worker_class = self._get_zcc_implementation_classes()
+
+        # Create ZCC manager with appropriate configuration
+        self.zcc_manager = self._create_zcc_manager_instance(unwrapped_model, zcc_worker_class)
+
+        # Register pipeline hooks if using pipeline parallelism
         if isinstance(self.model, PipelineLayer):
-            pipeline_hooks_capacity = (
-                unwrapped_model.forward_pipeline_parallel_hook_capacity
-                + unwrapped_model.backward_pipeline_parallel_hook_capacity
-            )
-            self.zcc_manager = ZeroCostCheckpointManager(
-                worker_num=self.args.zcc_workers_num,
-                pipeline_hooks_capacity=pipeline_hooks_capacity,
-                capacity_usage=self.args.zcc_pipeline_hooks_capacity_usage,
-                use_expert_parallel=self.args.use_expert_parallel,
-                ema_coef=self.args.zcc_save_ema_coef,
-            )
-            for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
-                unwrapped_model.register_forward_pipeline_parallel_hook(
-                    location=i, hook=self.zcc_manager.zcc_pipeline_hook
-                )
-            for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
-                unwrapped_model.register_backward_pipeline_parallel_hook(
-                    location=i, hook=self.zcc_manager.zcc_pipeline_hook
-                )
-        else:
-            pipeline_hooks_capacity = self.args.gradient_accumulation_steps
-            self.zcc_manager = ZeroCostCheckpointManager(
-                worker_num=self.args.zcc_workers_num,
-                pipeline_hooks_capacity=pipeline_hooks_capacity,
-                capacity_usage=self.args.zcc_pipeline_hooks_capacity_usage,
-                use_expert_parallel=self.args.use_expert_parallel,
-                ema_coef=self.args.zcc_save_ema_coef,
-            )
-        _callback = ZeroCostCheckpointCallback(self.args, self.zcc_manager, self.runtime_timer, self.sharding_io)
-        self.add_callback(_callback)
+            self._register_pipeline_hooks(unwrapped_model)
 
-        if resume_from_checkpoint is not None:
-            path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
-            path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
-            if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
-                success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
-            else:
-                success, err_msg = True, None
-            if os.path.exists(path):
-                if success:
-                    logger.info(f"ZCC EMA load from {path}")
-                    self.zcc_manager.set_ema_state_dict(path)
-                else:
-                    logger.info(f"ZCC EMA does not load {path} because {err_msg}")
-            else:
-                logger.info(f"ZCC EMA state dict not found, in: {path}")
+        # Add callback and handle checkpoint resumption
+        self._setup_zcc_callback(zcc_callback_class)
+
+        self._handle_checkpoint_resume(resume_from_checkpoint)
 
-        logger.info("Create zero cost checkpoint manager done.")
+        logger.info("Zero cost checkpoint manager created successfully.")
 
     def add_non_zcc_ema_callback(self, resume_from_checkpoint):
         self.add_callback(NonZCCEMACallback(resume_from_checkpoint, self.args, self.sharding_io))
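The refactor above replaces the duplicated ZeroCostCheckpointManager construction and the inline EMA-resume block with the helpers added earlier in the diff. As a quick illustration of the dispatch performed by _get_zcc_implementation_classes when save_checkpoint_format is "flex_checkpoint", here is a standalone sketch with placeholder classes standing in for the real callback and worker implementations; the names below are not from the commit.

class ZccCallback: ...          # placeholder for ZeroCostCheckpointCallback
class ZccWorker: ...            # placeholder for ZeroCostCheckpointWorker
class ZccCallbackFcBased: ...   # placeholder for ZeroCostCheckpointCallbackFcBased
class ZccWorkerFcBased: ...     # placeholder for ZeroCostCheckpointWorkerFcBased

def select_zcc_classes(save_checkpoint_format):
    """Return (callback_class, worker_class) for the configured checkpoint format."""
    if save_checkpoint_format == "flex_checkpoint":
        return ZccCallbackFcBased, ZccWorkerFcBased
    return ZccCallback, ZccWorker

callback_cls, worker_cls = select_zcc_classes("flex_checkpoint")
assert (callback_cls, worker_cls) == (ZccCallbackFcBased, ZccWorkerFcBased)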
