
Commit e774ede

fix racing condition

1 parent f54ae56 commit e774ede

10 files changed: +113 -37 lines changed

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 4 additions & 2 deletions

@@ -1,5 +1,6 @@
-from typing import Any, Dict
 import copy
+from typing import Any, Dict
+
 import ray
 import ray.util.collective as cc
 import torch
@@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
     obj = c10d._tensor_to_object(obj, size_tensor.item())
     return obj
 
+
 def ray_broadcast_tensor_dict(
     tensor_dict: Dict[str, torch.Tensor],
     src: int = 0,
@@ -98,7 +100,7 @@ def pickup_rollout_task(self, num_tasks: int):
         queue length as data may still be generating
         """
         ret = False
-        if self.queue_size < self.buffer_size_limit:
+        if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
            ret = True
            self.queue_size += num_tasks
        return ret
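The new condition scales the effective buffer limit by the inverse of the reported sample utilization (clamped to at least 0.1), so producers may queue more rollout tasks when the trainer discards a large fraction of generated samples. Below is a minimal standalone sketch of that check; the surrounding Ray actor and signal plumbing is simplified and hypothetical.

# Minimal sketch of the dynamic buffer check from the diff above; the
# surrounding actor/signal plumbing is simplified and hypothetical.
class RolloutTaskGate:
    def __init__(self, buffer_size_limit: int):
        self.buffer_size_limit = buffer_size_limit
        self.queue_size = 0
        self.signals = {}  # e.g. updated elsewhere with {"sample_utilization": 0.5}

    def pickup_rollout_task(self, num_tasks: int) -> bool:
        # Scale the limit by 1 / sample_utilization (clamped to >= 0.1) so that
        # producers may queue more work when the trainer uses samples slowly.
        utilization = max(0.1, self.signals.get("sample_utilization", 1.0))
        ret = False
        if self.queue_size < self.buffer_size_limit / utilization:
            ret = True
            self.queue_size += num_tasks
        return ret


gate = RolloutTaskGate(buffer_size_limit=4)
gate.signals["sample_utilization"] = 0.5   # only half the generated samples are used
print(gate.pickup_rollout_task(4))         # True: effective limit is 8, queue is 0
print(gate.pickup_rollout_task(8))         # True: queue (4) still below 8
print(gate.pickup_rollout_task(1))         # False: queue (12) now exceeds the limit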

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 5 deletions

@@ -250,6 +250,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
@@ -306,17 +309,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     "action_mask": action_mask_forward_micro_batch,
                     "advantages": advantages_forward_micro_batch,
                     "loss_mask": loss_mask_forward_micro_batch,
+                    "old_action_log_probs": old_action_log_probs_micro_batch,
                     "source": self.rank,
                 }
                 if reference_action_log_probs is not None:
                     data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                 kl = []
-                policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
-                    policy_model_logits.copy_(action_logits)
+                    mini_batch_entropies.append(
+                        (
+                            ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                            / inputs["action_mask"].sum(-1)
+                        ).detach()
+                    )
                     action_log_probs = memory_efficient_logprob(
                         action_logits / self.generate_config["temperature"],
                         inputs["input_ids"],
@@ -339,7 +347,7 @@ def _criterion(outputs, inputs):
 
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        action_log_probs,
+                        inputs["old_action_log_probs"],
                         inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         inputs["action_mask"],
@@ -415,7 +423,7 @@ def _criterion(outputs, inputs):
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
-                    old_action_log_probs,
+                    old_action_log_probs_micro_batch,
                     advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                     per_token_kl,
                     action_mask_forward_micro_batch,
@@ -455,7 +463,7 @@ def _criterion(outputs, inputs):
             ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+            entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
             self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
             self.accum_entropy.add_(entropy.data)
             if self.policy_loss_fn.beta > 0:
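Two fixes are visible here: the policy loss now receives the stored old_action_log_probs sliced for the current micro-batch instead of the freshly computed log-probs, and the per-token entropy is computed inside _criterion directly from the logits rather than by copying logits into a preallocated buffer. The sketch below reproduces only the masked entropy term; entropy_from_logits is re-implemented locally for self-containment and may differ from the repository's utility.

# Sketch of the masked per-sequence entropy that the new _criterion appends to
# mini_batch_entropies; entropy_from_logits is re-implemented here and may
# differ from the repository's version.
import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Shannon entropy of the categorical distribution at each token position.
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(probs * log_probs).sum(dim=-1)

def masked_sequence_entropy(action_logits, action_mask, num_action):
    # Average token entropy over the generated (action) part of each sequence,
    # counting only unmasked tokens, then detach so it is logged, not trained on.
    token_entropy = entropy_from_logits(action_logits[:, -num_action:])
    return ((token_entropy * action_mask).sum(-1) / action_mask.sum(-1)).detach()

batch, seq_len, num_action, vocab = 2, 10, 4, 32
logits = torch.randn(batch, seq_len, vocab)
mask = torch.ones(batch, num_action)
print(masked_sequence_entropy(logits, mask, num_action).shape)  # torch.Size([2])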

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 5 additions & 1 deletion

@@ -59,6 +59,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if sgl is None:
             raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(model=path, **model_config)
+        tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
+        self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
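With the added tokenizer_config parameter, the vLLM backend can load the tokenizer from a different path than the model weights, and it falls back to the model's own tokenizer when no path is given. A hedged sketch of that construction, with placeholder paths and the vllm import done lazily, mirroring the backend's optional dependency:

# Sketch of the vLLM backend change: forward an optional tokenizer path so the
# vocab/chat template can come from a different checkpoint than the weights.
# Paths below are placeholders; vLLM must be installed for build_llm to run.
from typing import Any, Dict, Optional

def build_llm(model_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None):
    from vllm import LLM  # imported lazily, as the backend treats vllm as optional

    model_config = dict(model_config)  # avoid mutating the caller's dict
    path = model_config.pop("path")
    # When tokenizer_path is None, vLLM uses the tokenizer bundled with the model.
    tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
    return LLM(model=path, tokenizer=tokenizer_path, **model_config)

# Example (hypothetical paths):
# llm = build_llm({"path": "Qwen/Qwen2.5-7B-Instruct"}, {"path": "Qwen/Qwen2.5-7B-Instruct"})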

applications/ColossalChat/coati/distributed/launch_zero_bubble.py

Lines changed: 2 additions & 3 deletions

@@ -130,7 +130,7 @@ def launch_distributed(
            train_dataset_config=train_dataset_config,
            model_config=inference_model_config,
            generate_config=generate_config,
-           tokenizer_config=tokenizer_config,
+           tokenizer_config=copy.deepcopy(tokenizer_config),
            microbatch_size=inference_microbatch_size,
            backend=inference_backend,
            num_generations=num_generations,
@@ -158,8 +158,6 @@ def launch_distributed(
    consumer_master_ip_address = gpu_to_ip_address[0]
    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
    consumer_procs = []
-   if num_consumer_procs <= 1:
-       raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
    for i in range(num_consumer_procs):
        node_id = gpu_to_node_id[0]
        consumer_ip_address = gpu_to_ip_address[0]
@@ -180,6 +178,7 @@ def launch_distributed(
            model_config=train_model_config,
            plugin_config=plugin_config,
            minibatch_size=train_minibatch_size,
+           tokenizer_config=copy.deepcopy(tokenizer_config),
            generate_config=generate_config_consumer,
            grpo_config=grpo_config,
            num_generations=num_generations,
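Deep-copying tokenizer_config before handing it to each Ray actor ensures that one constructor's in-place mutation of the dict cannot race with, or leak into, another actor's view of the same config. A small illustration (the pop-style mutation shown is hypothetical, used only to demonstrate the hazard):

# Why the launcher deep-copies tokenizer_config per actor: a dict passed by
# reference and mutated in one constructor would silently change what the next
# constructor sees.
import copy

def consume_config(cfg):
    # Hypothetical consumer that mutates its config in place, e.g. popping "path".
    return cfg.pop("path")

tokenizer_config = {"path": "some/tokenizer", "trust_remote_code": True}
consume_config(tokenizer_config)            # shares the same dict
print("path" in tokenizer_config)           # False: the shared dict was mutated

tokenizer_config = {"path": "some/tokenizer", "trust_remote_code": True}
consume_config(copy.deepcopy(tokenizer_config))
print("path" in tokenizer_config)           # True: the original is untouched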

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 2 additions & 2 deletions

@@ -35,9 +35,9 @@ def forward(
         total_effective_tokens_in_batch: torch.Tensor = None,
     ) -> torch.Tensor:
         if action_mask is None:
-            ratio = (log_probs - log_probs.detach()).exp()
+            ratio = (log_probs - old_log_probs.detach()).exp()
         else:
-            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+            ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
 
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
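Before this fix the ratio used log_probs - log_probs.detach(), which is identically zero, so the importance ratio was always 1 and the clipping never engaged; comparing against the stored old (behaviour-policy) log-probs restores the intended PPO-style surrogate. A simplified sketch of the corrected term, with shapes and the final reduction chosen here for illustration rather than taken from the repository:

# Sketch of the corrected clipped surrogate: the ratio now compares current
# log-probs against the stored old log-probs instead of being identically 1.
import torch

def clipped_surrogate(log_probs, old_log_probs, advantages, action_mask,
                      clip_eps_low=0.2, clip_eps_high=0.2):
    ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    per_token = -torch.min(surr1, surr2)                      # standard PPO objective
    return (per_token * action_mask).sum() / action_mask.sum().clamp_min(1)

log_probs = torch.tensor([[-1.0, -2.0]])
old_log_probs = torch.tensor([[-1.5, -1.5]])
advantages = torch.tensor([[1.0, 1.0]])
mask = torch.ones(1, 2)
print(clipped_surrogate(log_probs, old_log_probs, advantages, mask))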

applications/ColossalChat/coati/distributed/zero_bubble/consumer.py

Lines changed: 4 additions & 4 deletions

@@ -7,17 +7,16 @@
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.distributed.utils import bind_batch, post_recv, unbind_batch
 from tqdm import tqdm
 
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.utils import get_current_device
 
-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.utils import bind_batch, post_recv, unbind_batch
-
 
 class BaseConsumer:
     def __init__(
@@ -175,14 +174,15 @@ def loop(self) -> None:
                 raw_batch = ray.get(
                     self.shared_sync_data_actor.get_data.remote(self.data_uid)
                 ) # get the first queued data
+                self.profiler.log(f"enter sleep")
                 while raw_batch is None:
-                    self.profiler.log(f"No data received by consumer {self.rank}, skipping")
                     print(
                         f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
                     )
                     time.sleep(1)
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
+                self.profiler.log(f"exit sleep")
                 self.data_uid += 1
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
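The consumer now logs once when it starts waiting for data and once when data arrives, instead of on every retry, which makes the stall time easy to read off the profiler trace. A stripped-down sketch of that bracketed busy-wait, with a plain callable standing in for the Ray shared-data actor:

# Sketch of the consumer's bracketed wait: a single "enter sleep"/"exit sleep"
# pair around the polling loop measures how long the trainer stalled for data.
import time

def wait_for_batch(fetch, poll_interval=1.0, log=print):
    log("enter sleep")                      # consumer may be about to stall
    batch = fetch()
    while batch is None:
        log("no data yet, consider increasing the data actor buffer limit")
        time.sleep(poll_interval)
        batch = fetch()
    log("exit sleep")                       # data arrived, stall over
    return batch

attempts = iter([None, None, {"input_ids": [1, 2, 3]}])
print(wait_for_batch(lambda: next(attempts), poll_interval=0.01))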

applications/ColossalChat/coati/distributed/zero_bubble/distributor.py

Lines changed: 16 additions & 5 deletions

@@ -3,12 +3,11 @@
 import ray
 import ray.util.collective as cc
 import torch
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
 
 from colossalai.utils import get_current_device
 
-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-
 
 @ray.remote
 class Distributor:
@@ -21,6 +20,7 @@ def __init__(
         enable_profiling: bool = True,
     ):
         self.distributor_id = distributor_id
+        self.weight_version = [0] * consumer_pp_size
         self.consumer_pp_size = consumer_pp_size
         self.state_dict_cpu = {}
         self.num_producers = num_producers
@@ -42,14 +42,17 @@ def init_collective_group(
         print(f"[D] Initialized {group_name} collective group", flush=True)
 
     def loop(self):
+        last_weight_version = self.get_weight_version()
         while True:
            time.sleep(1)
            signal = ray.get(self.shared_signal_actor.get_signal.remote())
            if self.consumer_pp_size > 1:
-                for i in range(self.consumer_pp_size):
-                    if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
+                if all(
+                    [signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
+                ):
+                    cc.barrier(group_name="distributor_pg")
+                    for i in range(self.consumer_pp_size):
                         self.profiler.enter(f"sync_model_consumer_pp_{i}")
-                        cc.barrier(group_name="distributor_pg")
                         ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
                         # Broadcast the model state dict from consumer to shared variable actor
                         self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
@@ -60,6 +63,7 @@ def loop(self):
                             backend="gloo",
                         )
                         self.profiler.exit(f"sync_model_consumer_pp_{i}")
+                        self.weight_version[i] += 1
                 for i in range(self.consumer_pp_size):
                     if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
                         self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@@ -87,6 +91,7 @@ def loop(self):
                        None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
                    )
                    self.profiler.exit("sync_model_consumer")
+                   self.weight_version[0] += 1
                if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
                    self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
                    # Broadcast the model state dict to all producers
@@ -106,3 +111,9 @@ def loop(self):
            if signal.get("consumer", None) == "terminate":
                self.profiler.log("terminate sync model worker")
                break
+           if last_weight_version != self.get_weight_version():
+               last_weight_version = self.get_weight_version()
+               ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
+
+    def get_weight_version(self):
+        return min(self.weight_version)
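The distributor now keeps a per-pipeline-stage weight version, bumps a stage's counter after each successful model sync, and publishes the minimum across stages (the newest version that every stage has fully received) as the distributor_weight_version signal whenever it advances. A minimal sketch of that bookkeeping, with a plain callable standing in for the Ray signal actor:

# Minimal sketch of the distributor's weight-version bookkeeping: each pipeline
# stage increments its own counter after a sync, and the published version is
# the minimum over stages. The signal setter is a stand-in for the Ray actor.
class WeightVersionTracker:
    def __init__(self, consumer_pp_size: int):
        self.weight_version = [0] * consumer_pp_size
        self.published = 0

    def on_stage_synced(self, stage: int) -> None:
        self.weight_version[stage] += 1

    def get_weight_version(self) -> int:
        return min(self.weight_version)

    def maybe_publish(self, set_signal) -> None:
        # Only announce a new version once every stage has caught up.
        if self.published != self.get_weight_version():
            self.published = self.get_weight_version()
            set_signal("distributor_weight_version", self.published)

tracker = WeightVersionTracker(consumer_pp_size=2)
tracker.on_stage_synced(0)
tracker.maybe_publish(lambda k, v: print(k, v))  # silent: stage 1 still at version 0
tracker.on_stage_synced(1)
tracker.maybe_publish(lambda k, v: print(k, v))  # prints: distributor_weight_version 1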
