1818import numpy as np
1919import torch
2020from omegaconf import DictConfig
21+ from torch import nn
2122from torch .distributed .tensor import DTensor
2223from torch .multiprocessing .reductions import reduce_tensor
2324
@@ -660,43 +661,64 @@ def __init__(self, cfg: DictConfig):
660661 self ._env_group_name = cfg .env .group_name
661662 self ._rollout_group_name = cfg .rollout .group_name
662663 self ._component_placement = HybridComponentPlacement (cfg , Cluster ())
663- self ._weight_dst_rank_in_rollout = self ._rank
664- if self ._weight_dst_rank_in_rollout >= self ._component_placement .get_world_size (
665- "rollout"
666- ):
667- self ._weight_dst_rank_in_rollout = None
668664
669665 # stage_num: default to 2, use for pipeline rollout process
670666 self .stage_num = cfg .rollout .pipeline_stage_num
671667
672668 self .enable_offload = self .cfg .actor .get ("enable_offload" , False )
673669
674- def init_worker (self ):
670+ def _setup_rollout_weight_dst_ranks (self ) -> None :
671+ """
672+ Setup destination ranks for weight communication.
673+ It can support any topology between actor and rollout workers.
674+ Assuming there are M actor ranks and N rollout ranks, each actor rank
675+ will send weights to most ceil(N/M) rollout ranks according to the modulo rule.
676+ """
677+ rollout_world_size = self ._component_placement .get_world_size ("rollout" )
678+ actor_world_size = self ._world_size
679+ rank = self ._rank
680+ self ._weight_dst_rank_in_rollout = []
681+ rollout_ranks_per_actor = (
682+ rollout_world_size + actor_world_size - 1
683+ ) // actor_world_size
684+ for i in range (rollout_ranks_per_actor ):
685+ if i * actor_world_size + rank < rollout_world_size :
686+ self ._weight_dst_rank_in_rollout .append (i * actor_world_size + rank )
687+
    def init_worker(self) -> None:
        """
        Initialize the actor worker.

        Builds the model with the corresponding training backend and, when
        ``enable_offload`` is set, moves parameters/gradients and optimizer
        state off the device. Finally records the rollout ranks this actor
        will send weights to.
        """
        self.setup_model_and_optimizer()

        # Offload immediately after setup so device memory is free for the
        # rollout phase; weights are reloaded on demand before sync/training.
        if self.enable_offload:
            self.offload_param_and_grad()
            self.offload_optimizer()
        # Destination ranks depend only on static component placement,
        # so compute them once here.
        self._setup_rollout_weight_dst_ranks()
680699
681- def model_provider_func (self ):
700+ def model_provider_func (self ) -> nn . Module :
682701 model = get_model (self .cfg .actor .model )
683702 if model is not None :
684703 return model
685704 return super ().model_provider_func ()
686705
687- def sync_model_to_rollout (self ):
706+ def sync_model_to_rollout (self ) -> None :
707+ """
708+ Sync the model's full state dict to the rollout worker.
709+ """
688710 if self .enable_offload and not self .is_optimizer_offloaded :
689711 self .offload_optimizer ()
690712
691713 if self .enable_offload and self .is_weight_offloaded :
692714 self .load_param_and_grad (self .device )
693715
694716 state_dict = self .get_model_state_dict (cpu_offload = False , full_state_dict = True )
695- if self ._weight_dst_rank_in_rollout is not None :
717+ for rank in self ._weight_dst_rank_in_rollout :
696718 self .send (
697719 state_dict ,
698720 self ._rollout_group_name ,
699- self . _weight_dst_rank_in_rollout ,
721+ rank ,
700722 async_op = True ,
701723 )
702724 if self .enable_offload and not self .is_weight_offloaded :
@@ -705,6 +727,9 @@ def sync_model_to_rollout(self):
705727 def recv_rollout_batch (self , input_channel : Channel ) -> None :
706728 """
707729 Receive rollout batch from rollout workers.
730+
731+ Args:
732+ input_channel: The input channel to read from.
708733 """
709734 send_num = self ._component_placement .get_world_size ("rollout" ) * self .stage_num
710735 recv_num = self ._component_placement .get_world_size ("actor" )
@@ -808,7 +833,10 @@ def _process_received_rollout_batch(
808833
809834 return rollout_batch
810835
811- def compute_advantages_and_returns (self ):
836+ def compute_advantages_and_returns (self ) -> dict [str , torch .Tensor ]:
837+ """
838+ Compute the advantages and returns.
839+ """
812840 kwargs = {
813841 "task_type" : self .cfg .runner .task_type ,
814842 "adv_type" : self .cfg .algorithm .adv_type ,
@@ -834,7 +862,10 @@ def compute_advantages_and_returns(self):
834862 rollout_metrics = compute_rollout_metrics (self .rollout_batch )
835863 return rollout_metrics
836864
837- def run_training (self ):
865+ def run_training (self ) -> None :
866+ """
867+ Run the training process using the received rollout batch.
868+ """
838869 if self .is_weight_offloaded :
839870 self .load_param_and_grad (self .device )
840871 if self .is_optimizer_offloaded :
@@ -1012,6 +1043,9 @@ def run_training(self):
10121043
10131044 return mean_metric_dict
10141045
1015- def set_global_step (self , global_step ):
1046+ def set_global_step (self , global_step ) -> None :
1047+ """
1048+ Set the global step for the model, if needed.
1049+ """
10161050 if hasattr (self .model , "set_global_step" ):
10171051 self .model .set_global_step (global_step )
0 commit comments