188188
189189try :
190190 from .utils .zero_cost_checkpoint import (
191+ DistInfoCollectorValidator ,
191192 NonZCCEMACallback ,
192193 ZeroCostCheckpointCallback ,
194+ ZeroCostCheckpointCallbackFcBased ,
193195 ZeroCostCheckpointManager ,
194- get_fused_param_mappings ,
196+ ZeroCostCheckpointWorker ,
197+ ZeroCostCheckpointWorkerFcBased ,
195198 )
196199except (ImportError , ModuleNotFoundError ):
197- ZeroCostCheckpointManager , NonZCCEMACallback , get_fused_param_mappings = None , None , None
200+ ZeroCostCheckpointManager , NonZCCEMACallback = None , None
201+
198202from .utils .helper import ( # nested_truncate,
199203 broadcast_dataset_rank0_model ,
200204 broadcast_dp_optimizer ,
@@ -796,6 +800,90 @@ def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
796800 self ._load_from_checkpoint (resume_from_checkpoint )
797801 return model
798802
def _get_zcc_implementation_classes(self):
    """Select the ZCC callback and worker classes for the configured checkpoint format.

    Returns:
        A ``(callback_class, worker_class)`` pair. The ``FcBased`` variants are
        returned when ``self.args.save_checkpoint_format`` equals
        ``"flex_checkpoint"``; otherwise the default ZCC callback and worker
        classes are returned.
    """
    uses_flex_checkpoint = self.args.save_checkpoint_format == "flex_checkpoint"
    # Conditional expressions only evaluate the selected class name, matching
    # the original early-return form.
    callback_class = ZeroCostCheckpointCallbackFcBased if uses_flex_checkpoint else ZeroCostCheckpointCallback
    worker_class = ZeroCostCheckpointWorkerFcBased if uses_flex_checkpoint else ZeroCostCheckpointWorker
    return callback_class, worker_class
808+
def _create_zcc_manager_instance(self, unwrapped_model, zcc_worker_class):
    """Build a ``ZeroCostCheckpointManager`` configured from the trainer arguments.

    Args:
        unwrapped_model: Model whose forward/backward pipeline hook capacities
            are read when ``self.model`` is a ``PipelineLayer``.
        zcc_worker_class: Worker class forwarded to the manager as
            ``zcc_worker_class``.

    Returns:
        The newly constructed ``ZeroCostCheckpointManager``.
    """
    if not isinstance(self.model, PipelineLayer):
        # Without pipeline parallelism the hook capacity is taken from the
        # number of gradient accumulation steps.
        hooks_capacity = self.args.gradient_accumulation_steps
    else:
        forward_capacity = unwrapped_model.forward_pipeline_parallel_hook_capacity
        backward_capacity = unwrapped_model.backward_pipeline_parallel_hook_capacity
        hooks_capacity = forward_capacity + backward_capacity

    manager_kwargs = {
        "worker_num": self.args.zcc_workers_num,
        "pipeline_hooks_capacity": hooks_capacity,
        "capacity_usage": self.args.zcc_pipeline_hooks_capacity_usage,
        "use_expert_parallel": self.args.use_expert_parallel,
        "ema_coef": self.args.zcc_save_ema_coef,
        "zcc_worker_class": zcc_worker_class,
    }
    return ZeroCostCheckpointManager(**manager_kwargs)
827+
def _register_pipeline_hooks(self, unwrapped_model):
    """Attach the ZCC pipeline hook to every forward and backward hook location.

    ``self.zcc_manager.zcc_pipeline_hook`` is registered at each location index
    ``0 .. capacity - 1``, first for all forward locations and then for all
    backward locations.

    Args:
        unwrapped_model: Object exposing the ``*_pipeline_parallel_hook_capacity``
            attributes and the matching ``register_*_pipeline_parallel_hook``
            methods.
    """
    forward_slots = unwrapped_model.forward_pipeline_parallel_hook_capacity
    for slot in range(forward_slots):
        register_forward = unwrapped_model.register_forward_pipeline_parallel_hook
        register_forward(location=slot, hook=self.zcc_manager.zcc_pipeline_hook)

    backward_slots = unwrapped_model.backward_pipeline_parallel_hook_capacity
    for slot in range(backward_slots):
        register_backward = unwrapped_model.register_backward_pipeline_parallel_hook
        register_backward(location=slot, hook=self.zcc_manager.zcc_pipeline_hook)
841+
def _setup_zcc_callback(self, zcc_callback_class):
    """Instantiate the given ZCC callback class and add it to the trainer.

    Args:
        zcc_callback_class: Callable invoked with the trainer's ``args``,
            ``zcc_manager``, ``runtime_timer`` and ``sharding_io`` (in that
            order); its result is passed to ``self.add_callback``.
    """
    dependencies = (self.args, self.zcc_manager, self.runtime_timer, self.sharding_io)
    zcc_callback = zcc_callback_class(*dependencies)
    self.add_callback(zcc_callback)
846+
def _handle_checkpoint_resume(self, resume_from_checkpoint):
    """Restore the ZCC EMA state when resuming from a checkpoint.

    Nothing happens when ``resume_from_checkpoint`` is ``None``. Otherwise the
    EMA state path is resolved; a missing file is only logged. When the file
    exists and ``_should_load_ema_state`` approves, the path is handed to
    ``self.zcc_manager.set_ema_state_dict``.

    Args:
        resume_from_checkpoint: Checkpoint directory to resume from, or ``None``.
    """
    if resume_from_checkpoint is None:
        return

    ema_state_path = self._get_ema_state_path(resume_from_checkpoint)

    if os.path.exists(ema_state_path):
        # The strategy-compatibility check runs only once the file is known to exist.
        if self._should_load_ema_state(resume_from_checkpoint, ema_state_path):
            logger.info(f"Loading ZCC EMA state from: { ema_state_path } ")
            self.zcc_manager.set_ema_state_dict(ema_state_path)
    else:
        logger.info(f"ZCC EMA state dict not found at: { ema_state_path } ")
864+
def _get_ema_state_path(self, checkpoint_path):
    """Return the path of the EMA state file inside ``checkpoint_path``.

    For the ``"flex_checkpoint"`` format the file is
    ``<checkpoint_path>/ema_state/<rank>_0.distcp``, where ``<rank>`` is
    ``dist.get_rank()``. For every other format the optimizer file name is
    derived with ``_add_variant`` and its ``"optimizer"`` substring is
    replaced by ``"ema"``.

    Args:
        checkpoint_path: Directory of the checkpoint being resumed.

    Returns:
        The EMA state path as a string. The file is not guaranteed to exist.
    """
    if self.args.save_checkpoint_format == "flex_checkpoint":
        # No whitespace between the rank and the "_0" suffix: a stray space
        # would produce a file name such as "0 _0.distcp".
        return os.path.join(checkpoint_path, "ema_state", f"{dist.get_rank()}_0.distcp")

    optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
    # Substitute inside the file name only. Replacing on the joined path would
    # also rewrite any "optimizer" substring in the checkpoint directory itself
    # and point at a different location.
    ema_state_path = os.path.join(checkpoint_path, optimizer_name.replace("optimizer", "ema"))

    # Backward compatibility: the previous implementation substituted on the
    # whole path. Prefer that location only when it exists and the corrected
    # one does not.
    legacy_path = os.path.join(checkpoint_path, optimizer_name).replace("optimizer", "ema")
    if legacy_path != ema_state_path and not os.path.exists(ema_state_path) and os.path.exists(legacy_path):
        return legacy_path
    return ema_state_path
872+
def _should_load_ema_state(self, checkpoint_path, ema_state_path):
    """Decide whether the EMA state of a checkpoint should be loaded.

    Loading is refused when ``self.args.zcc_save_ema_coef`` is ``None`` or when
    ``DistInfoCollectorValidator.check_same_strategy`` reports a mismatch for
    ``checkpoint_path``.

    Args:
        checkpoint_path: Checkpoint directory passed to the strategy check.
        ema_state_path: Path of the EMA state file (not used by this check).

    Returns:
        ``True`` if the EMA state should be loaded, ``False`` otherwise.
    """
    if self.args.zcc_save_ema_coef is None:
        logger.info("EMA coefficient is None, skipping EMA state loading")
        return False

    validator = DistInfoCollectorValidator(self.args, self.hcg)
    success, err_msg = validator.check_same_strategy(checkpoint_path)
    if success:
        return True

    logger.warning(f"Cannot load EMA state due to strategy mismatch: { err_msg } ")
    return False
886+
799887 def create_zcc_manager (self , unwrapped_model , resume_from_checkpoint = None ):
800888 """
801889 Create zero cost checkpoint manager.
@@ -806,55 +894,22 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
806894 self .model , PretrainedModel
807895 ), "model should be a PretrainedModel when using zero cost checkpoint"
808896 logger .info ("Create zero cost checkpoint manager..." )
897+
898+ zcc_callback_class , zcc_worker_class = self ._get_zcc_implementation_classes ()
899+
900+ # Create ZCC manager with appropriate configuration
901+ self .zcc_manager = self ._create_zcc_manager_instance (unwrapped_model , zcc_worker_class )
902+
903+ # Register pipeline hooks if using pipeline parallelism
809904 if isinstance (self .model , PipelineLayer ):
810- pipeline_hooks_capacity = (
811- unwrapped_model .forward_pipeline_parallel_hook_capacity
812- + unwrapped_model .backward_pipeline_parallel_hook_capacity
813- )
814- self .zcc_manager = ZeroCostCheckpointManager (
815- worker_num = self .args .zcc_workers_num ,
816- pipeline_hooks_capacity = pipeline_hooks_capacity ,
817- capacity_usage = self .args .zcc_pipeline_hooks_capacity_usage ,
818- use_expert_parallel = self .args .use_expert_parallel ,
819- ema_coef = self .args .zcc_save_ema_coef ,
820- )
821- for i in range (unwrapped_model .forward_pipeline_parallel_hook_capacity ):
822- unwrapped_model .register_forward_pipeline_parallel_hook (
823- location = i , hook = self .zcc_manager .zcc_pipeline_hook
824- )
825- for i in range (unwrapped_model .backward_pipeline_parallel_hook_capacity ):
826- unwrapped_model .register_backward_pipeline_parallel_hook (
827- location = i , hook = self .zcc_manager .zcc_pipeline_hook
828- )
829- else :
830- pipeline_hooks_capacity = self .args .gradient_accumulation_steps
831- self .zcc_manager = ZeroCostCheckpointManager (
832- worker_num = self .args .zcc_workers_num ,
833- pipeline_hooks_capacity = pipeline_hooks_capacity ,
834- capacity_usage = self .args .zcc_pipeline_hooks_capacity_usage ,
835- use_expert_parallel = self .args .use_expert_parallel ,
836- ema_coef = self .args .zcc_save_ema_coef ,
837- )
838- _callback = ZeroCostCheckpointCallback (self .args , self .zcc_manager , self .runtime_timer , self .sharding_io )
839- self .add_callback (_callback )
905+ self ._register_pipeline_hooks (unwrapped_model )
840906
841- if resume_from_checkpoint is not None :
842- path = _add_variant (PADDLE_OPTIMIZER_NAME , self .args .optimizer_name_suffix )
843- path = os .path .join (resume_from_checkpoint , path ).replace ("optimizer" , "ema" )
844- if self .args .zcc_save_ema_coef is not None and self .sharding_io is not None :
845- success , err_msg = self .sharding_io .check_same_strategy (resume_from_checkpoint )
846- else :
847- success , err_msg = True , None
848- if os .path .exists (path ):
849- if success :
850- logger .info (f"ZCC EMA load from { path } " )
851- self .zcc_manager .set_ema_state_dict (path )
852- else :
853- logger .info (f"ZCC EMA does not load { path } because { err_msg } " )
854- else :
855- logger .info (f"ZCC EMA state dict not found, in: { path } " )
907+ # Add callback and handle checkpoint resumption
908+ self ._setup_zcc_callback (zcc_callback_class )
909+
910+ self ._handle_checkpoint_resume (resume_from_checkpoint )
856911
857- logger .info ("Create zero cost checkpoint manager done ." )
912+ logger .info ("Zero cost checkpoint manager created successfully ." )
858913
859914 def add_non_zcc_ema_callback (self , resume_from_checkpoint ):
860915 self .add_callback (NonZCCEMACallback (resume_from_checkpoint , self .args , self .sharding_io ))
0 commit comments