Commit e96f19f

fix(embodied): fix hang when rollout world size is greater than actor's (RLinf#487)

Signed-off-by: Bo Dai <daibo@infini-ai.com>

1 parent: 53d89ce

2 files changed: +53 −14

rlinf/workers/actor/fsdp_actor_worker.py (+47 −13)
@@ -18,6 +18,7 @@
 import numpy as np
 import torch
 from omegaconf import DictConfig
+from torch import nn
 from torch.distributed.tensor import DTensor
 from torch.multiprocessing.reductions import reduce_tensor

@@ -660,43 +661,64 @@ def __init__(self, cfg: DictConfig):
         self._env_group_name = cfg.env.group_name
         self._rollout_group_name = cfg.rollout.group_name
         self._component_placement = HybridComponentPlacement(cfg, Cluster())
-        self._weight_dst_rank_in_rollout = self._rank
-        if self._weight_dst_rank_in_rollout >= self._component_placement.get_world_size(
-            "rollout"
-        ):
-            self._weight_dst_rank_in_rollout = None

         # stage_num: defaults to 2, used for the pipelined rollout process
         self.stage_num = cfg.rollout.pipeline_stage_num

         self.enable_offload = self.cfg.actor.get("enable_offload", False)

-    def init_worker(self):
+    def _setup_rollout_weight_dst_ranks(self) -> None:
+        """
+        Set up destination ranks for weight communication.
+        It supports any topology between actor and rollout workers.
+        Assuming there are M actor ranks and N rollout ranks, each actor rank
+        sends weights to at most ceil(N/M) rollout ranks according to the modulo rule.
+        """
+        rollout_world_size = self._component_placement.get_world_size("rollout")
+        actor_world_size = self._world_size
+        rank = self._rank
+        self._weight_dst_rank_in_rollout = []
+        rollout_ranks_per_actor = (
+            rollout_world_size + actor_world_size - 1
+        ) // actor_world_size
+        for i in range(rollout_ranks_per_actor):
+            if i * actor_world_size + rank < rollout_world_size:
+                self._weight_dst_rank_in_rollout.append(i * actor_world_size + rank)
+
+    def init_worker(self) -> None:
+        """
+        Initialize the actor worker: build the model with the corresponding
+        training backend and, if needed, offload model parameters and
+        optimizer states to CPU.
+        """
         self.setup_model_and_optimizer()

         if self.enable_offload:
             self.offload_param_and_grad()
             self.offload_optimizer()
+        self._setup_rollout_weight_dst_ranks()

-    def model_provider_func(self):
+    def model_provider_func(self) -> nn.Module:
         model = get_model(self.cfg.actor.model)
         if model is not None:
             return model
         return super().model_provider_func()

-    def sync_model_to_rollout(self):
+    def sync_model_to_rollout(self) -> None:
+        """
+        Sync the model's full state dict to the rollout workers.
+        """
         if self.enable_offload and not self.is_optimizer_offloaded:
             self.offload_optimizer()

         if self.enable_offload and self.is_weight_offloaded:
             self.load_param_and_grad(self.device)

         state_dict = self.get_model_state_dict(cpu_offload=False, full_state_dict=True)
-        if self._weight_dst_rank_in_rollout is not None:
+        for rank in self._weight_dst_rank_in_rollout:
             self.send(
                 state_dict,
                 self._rollout_group_name,
-                self._weight_dst_rank_in_rollout,
+                rank,
                 async_op=True,
             )
         if self.enable_offload and not self.is_weight_offloaded:
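
The hang named in the commit title came from the old one-to-one mapping: rollout rank r expected weights from actor rank r, so whenever the rollout world size N exceeded the actor world size M, rollout ranks r >= M waited on a sender that never existed. `_setup_rollout_weight_dst_ranks` replaces this with the modulo rule from its docstring. A minimal standalone sketch of the rule (the helper name and demo values are illustrative, not RLinf code):

```python
import math

def weight_dst_ranks(actor_rank: int, m: int, n: int) -> list[int]:
    """Rollout ranks that actor `actor_rank` sends weights to,
    given m actor ranks and n rollout ranks: at most ceil(n/m) each."""
    return [
        i * m + actor_rank
        for i in range(math.ceil(n / m))  # same as (n + m - 1) // m
        if i * m + actor_rank < n
    ]

# With m=2 actors and n=5 rollout workers:
#   actor 0 -> [0, 2, 4], actor 1 -> [1, 3]
# Every rollout rank now has exactly one sender, so no recv() blocks forever.
for a in range(2):
    print(a, weight_dst_ranks(a, m=2, n=5))
```

Equivalently, rollout rank r receives from actor rank r % M, which is exactly the source rank the HuggingFace rollout worker below now computes.
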
@@ -705,6 +727,9 @@ def sync_model_to_rollout(self):
     def recv_rollout_batch(self, input_channel: Channel) -> None:
         """
         Receive rollout batch from rollout workers.
+
+        Args:
+            input_channel: The input channel to read from.
         """
         send_num = self._component_placement.get_world_size("rollout") * self.stage_num
         recv_num = self._component_placement.get_world_size("actor")
@@ -808,7 +833,10 @@ def _process_received_rollout_batch(

         return rollout_batch

-    def compute_advantages_and_returns(self):
+    def compute_advantages_and_returns(self) -> dict[str, torch.Tensor]:
+        """
+        Compute the advantages and returns.
+        """
         kwargs = {
             "task_type": self.cfg.runner.task_type,
             "adv_type": self.cfg.algorithm.adv_type,
@@ -834,7 +862,10 @@ def compute_advantages_and_returns(self):
         rollout_metrics = compute_rollout_metrics(self.rollout_batch)
         return rollout_metrics

-    def run_training(self):
+    def run_training(self) -> None:
+        """
+        Run the training process using the received rollout batch.
+        """
         if self.is_weight_offloaded:
             self.load_param_and_grad(self.device)
         if self.is_optimizer_offloaded:
@@ -1012,6 +1043,9 @@ def run_training(self):

         return mean_metric_dict

-    def set_global_step(self, global_step):
+    def set_global_step(self, global_step) -> None:
+        """
+        Set the global step for the model, if needed.
+        """
         if hasattr(self.model, "set_global_step"):
             self.model.set_global_step(global_step)

rlinf/workers/rollout/hf/huggingface_worker.py (+6 −1)
@@ -41,6 +41,9 @@ def __init__(self, cfg: DictConfig):

         self.placement = HybridComponentPlacement(cfg, Cluster())

+        actor_world_size = self.placement.get_world_size("actor")
+        self.actor_weight_src_rank = self._rank % actor_world_size
+
     def init_worker(self):
         rollout_model_config = copy.deepcopy(self.cfg.actor.model)
         with open_dict(rollout_model_config):
@@ -160,7 +163,9 @@ def get_dones_and_rewards(

     def sync_model_from_actor(self):
         """Sync model parameters from the actor worker."""
-        param_state_dict = self.recv(self.actor_group_name, src_rank=self._rank)
+        param_state_dict = self.recv(
+            self.actor_group_name, src_rank=self.actor_weight_src_rank
+        )

         self.hf_model.load_state_dict(param_state_dict)
         del param_state_dict
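
The receiver-side change is the inverse of the sender-side rule: `src_rank = self._rank % actor_world_size`, so sender and receiver always agree on who pairs with whom. A quick self-contained sanity check (again a sketch with illustrative names, not project code) that the two mappings give every rollout rank exactly one actor source:

```python
import math

def weight_dst_ranks(actor_rank: int, m: int, n: int) -> list[int]:
    # Sender-side modulo rule: at most ceil(n/m) destinations per actor.
    return [i * m + actor_rank for i in range(math.ceil(n / m))
            if i * m + actor_rank < n]

def check_pairing(m: int, n: int) -> None:
    src = {}
    for a in range(m):  # sender side: actor a's destinations
        for dst in weight_dst_ranks(a, m, n):
            assert dst not in src, "two actors send to one rollout rank"
            src[dst] = a
    for r in range(n):  # receiver side: rollout r's expected source
        assert src[r] == r % m

for m, n in [(1, 4), (2, 5), (4, 2), (3, 3)]:
    check_pairing(m, n)  # holds for any actor/rollout topology
```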
