[feat] Support Zero Bubble StreamRL-like RL Training #6356

Open

wants to merge 6 commits into base: grpo-latest

Changes from all commits
115 changes: 106 additions & 9 deletions applications/ColossalChat/coati/distributed/comm.py
@@ -1,5 +1,7 @@
+import copy
 from typing import Any, Dict
 
+import ray
 import ray.util.collective as cc
 import torch
 import torch.distributed.distributed_c10d as c10d
@@ -32,26 +34,121 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
 
 
 def ray_broadcast_tensor_dict(
-    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+    tensor_dict: Dict[str, torch.Tensor],
+    src: int = 0,
+    device=None,
+    group_name: str = "default",
+    backend: str = "nccl",
+    offload_to_cpu: bool = False,
+    pin_memory: bool = False,
 ) -> Dict[str, torch.Tensor]:
     rank = cc.get_rank(group_name)
+    if tensor_dict is None:
+        tensor_dict = {}
     if rank == src:
         metadata = []
         for k, v in tensor_dict.items():
             metadata.append((k, v.shape, v.dtype))
     else:
         metadata = None
     metadata = ray_broadcast_object(metadata, src, device, group_name)
-    if rank != src:
-        out_dict = {}
     for k, shape, dtype in metadata:
         if rank == src:
-            tensor = tensor_dict[k]
+            if offload_to_cpu:
+                tensor = tensor_dict[k].to(device)
+            else:
+                tensor = tensor_dict[k]
         else:
-            tensor = torch.empty(shape, dtype=dtype, device=device)
+            tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
+        if backend == "gloo" and dtype == torch.bfloat16:
+            # Gloo does not support bfloat16, convert to float16
+            tensor = tensor.view(torch.float16)
         cc.broadcast(tensor, src, group_name)
-        if rank != src:
-            out_dict[k] = tensor
-    if rank == src:
-        out_dict = tensor_dict
-    return out_dict
+        if backend == "gloo" and dtype == torch.bfloat16:
+            # Convert back to bfloat16 if it was converted to float16
+            tensor = tensor.view(torch.bfloat16)
+        if offload_to_cpu:
+            tensor_dict[k] = tensor.cpu()
+        else:
+            tensor_dict[k] = tensor
+    return tensor_dict
+
+
+@ray.remote
+class SharedVariableActor:
+    def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
+        self.data_queue = []
+        self.data_uid = 0
+        self.number_of_readers = number_of_readers
+        self.queue_size = 0
+        self.signals = {}
+        self.process_locks = {}
+        self.signal_procs_meet_count = {}
+        self.buffer_size_limit = buffer_size_limit
+
+    def pickup_rollout_task(self, num_tasks: int):
+        """
+        Use the queue size to decide whether producers should generate new rollouts or wait
+        for the consumer to consume more data. If the queue size is below the threshold,
+        the consumer is keeping up, so producers may generate new rollouts; if it is above
+        the threshold, the consumer is consuming data slowly, so producers should wait.
+
+        Any free producer can pick up the task to generate a rollout and then increase
+        queue_size so that other producers do not pick up the same task redundantly.
+        Note that this is not the real queue length, since some data may still be being generated.
+        """
+        ret = False
+        if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
+            ret = True
+            self.queue_size += num_tasks
+        return ret
+
+    def append_data(self, data):
+        self.data_queue.append([self.data_uid, data, 0])  # [data_uid, data, access_count]
+        self.data_uid += 1
+        return True
+
+    def get_data(self, data_uid: int):
+        # for multi-process data reading
+        if not self.data_queue:
+            # no data in the queue, return None
+            return None
+        to_pop_index = None
+        ret = None
+        for i, (uid, data, access_count) in enumerate(self.data_queue):
+            if uid == data_uid:
+                # found the data with the given uid
+                self.data_queue[i][2] += 1
+                ret = copy.deepcopy(data)
+                if self.data_queue[i][2] == self.number_of_readers:
+                    to_pop_index = i
+                break
+        if to_pop_index is not None:
+            # remove the data from the queue once it has been accessed by all readers
+            self.data_queue.pop(to_pop_index)
+            self.queue_size -= data["input_ids"].size(0)
+        return ret
+
+    def acquire_process_lock(self, key: str):
+        # atomic lock for a process; returns 0 on success, 1 if the lock is already held
+        if key not in self.process_locks:
+            self.process_locks[key] = 1  # locked
+            return 0
+        if self.process_locks[key] == 0:
+            self.process_locks[key] = 1  # lock the process
+            return 0
+        else:
+            return 1
+
+    def release_process_lock(self, key: str):
+        # atomic unlock for a process
+        assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
+        self.process_locks[key] = 0
+
+    def set_signal(self, key: str, signal: str):
+        self.signals[key] = signal
+
+    def get_signal(self):
+        return self.signals
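
The updated ray_broadcast_tensor_dict adds three knobs: backend (so the same helper also works over a CPU-side gloo group, bit-casting bfloat16 tensors to float16 for transport and back on arrival), offload_to_cpu (the synced copies are kept in host memory), and pin_memory (receive buffers are allocated pinned). Below is a minimal sketch, not part of this PR, of how the helper might be driven from Ray actors that share a gloo group; the Worker actor, group name, and tensor shapes are illustrative, and the import assumes the coati package is on the path:

import ray
import ray.util.collective as cc
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


@ray.remote
class Worker:
    def __init__(self, rank: int, world_size: int):
        # Join a CPU-side gloo group so weights can be synced without touching NCCL.
        cc.init_collective_group(world_size, rank, backend="gloo", group_name="sync")
        self.rank = rank

    def sync(self):
        # Only the source rank owns the data; receivers may pass None.
        tensor_dict = {"weight": torch.randn(4, 4, dtype=torch.bfloat16)} if self.rank == 0 else None
        return ray_broadcast_tensor_dict(
            tensor_dict,
            src=0,
            device="cpu",
            group_name="sync",
            backend="gloo",        # bf16 is viewed as fp16 for gloo, then viewed back
            offload_to_cpu=True,   # keep the synced copy on the CPU
            pin_memory=True,       # receivers allocate pinned buffers
        )


ray.init()
workers = [Worker.remote(rank, world_size=2) for rank in range(2)]
out = ray.get([w.sync.remote() for w in workers])
print(out[1]["weight"].dtype)  # torch.bfloat16 on the receiver as well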
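
The new SharedVariableActor is a small Ray-side coordination service between rollout producers and trainer consumers: pickup_rollout_task throttles generation once the tracked queue size exceeds buffer_size_limit divided by the reported sample_utilization signal (floored at 0.1), append_data enqueues a rollout under a monotonically increasing uid, and get_data hands each rollout to every reader exactly once before dropping it. A standalone sketch of that protocol, illustrative rather than part of the PR (single reader, dummy rollouts):

import ray
import torch

from coati.distributed.comm import SharedVariableActor

ray.init()

# One shared coordination actor; number_of_readers is how many consumers must
# fetch a rollout before the actor may drop it from its queue.
shared = SharedVariableActor.remote(number_of_readers=1, buffer_size_limit=64)


def producer_step(batch_size: int = 8) -> bool:
    # Ask for permission first, so a full buffer pauses generation (back-pressure).
    if not ray.get(shared.pickup_rollout_task.remote(batch_size)):
        return False
    # Placeholder rollout; a real producer would run inference here.
    rollout = {"input_ids": torch.zeros(batch_size, 16, dtype=torch.long)}
    ray.get(shared.append_data.remote(rollout))
    return True


def consumer_step(uid: int):
    # Returns None until the rollout with this uid exists; once every reader has
    # fetched it, the actor pops it and shrinks its tracked queue size.
    return ray.get(shared.get_data.remote(uid))


producer_step()
print(consumer_step(0)["input_ids"].shape)  # torch.Size([8, 16])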
(second changed file; file header not captured)
@@ -59,6 +59,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if sgl is None:
             raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(model=path, **model_config)
+        tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
+        self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
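
All three inference backend constructors gain an optional tokenizer_config, and the vLLM-backed one forwards its path to the engine so the tokenizer can be loaded from a different location than the model weights. A hedged sketch of how a caller might use it; the class name VLLMInferenceBackend and module coati.distributed.inference_backend are assumptions (neither appears in the captured diff), and the model, sampling settings, and paths are placeholders:

from transformers import AutoTokenizer

# Assumed location and name of the vLLM-backed class; not shown in this diff.
from coati.distributed.inference_backend import VLLMInferenceBackend

model_config = {"path": "Qwen/Qwen2.5-7B-Instruct", "dtype": "bfloat16"}
generate_config = {"temperature": 0.8, "top_p": 0.95}

# Point the engine at a tokenizer that differs from the model weights,
# e.g. one carrying a patched chat template.
tokenizer_config = {"path": "/data/custom-tokenizer"}
tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["path"])

backend = VLLMInferenceBackend(
    model_config=model_config,
    generate_config=generate_config,
    tokenizer=tokenizer,
    num_generations=8,
    tokenizer_config=tokenizer_config,  # new in this PR; defaults to None
)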