
Commit 2336d7f

fix race condition
1 parent ddda79c commit 2336d7f

10 files changed: +100 -33 lines


applications/ColossalChat/coati/distributed/comm.py

Lines changed: 4 additions & 2 deletions

@@ -1,5 +1,6 @@
-from typing import Any, Dict
 import copy
+from typing import Any, Dict
+
 import ray
 import ray.util.collective as cc
 import torch
@@ -31,6 +32,7 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
     obj = c10d._tensor_to_object(obj, size_tensor.item())
     return obj

+
 def ray_broadcast_tensor_dict(
     tensor_dict: Dict[str, torch.Tensor],
     src: int = 0,
@@ -98,7 +100,7 @@ def pickup_rollout_task(self, num_tasks: int):
         queue length as data may still be generating
         """
         ret = False
-        if self.queue_size < self.buffer_size_limit:
+        if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
             ret = True
         self.queue_size += num_tasks
         return ret
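The pickup_rollout_task change makes the rollout queue limit adaptive: when the consumer reports low sample utilization (many generated groups are filtered out before training), producers are allowed to queue proportionally more tasks. A minimal sketch of the scaling rule, assuming a signals dict that mirrors the one the consumer publishes through the shared signal actor (names outside the diff are illustrative):

def effective_buffer_limit(buffer_size_limit: float, signals: dict) -> float:
    # Scale the nominal limit by the inverse of sample utilization,
    # clamped at 0.1 so very low utilization cannot grow the queue unboundedly.
    utilization = max(0.1, signals.get("sample_utilization", 1.0))
    return buffer_size_limit / utilization

# Example: with 50% of samples surviving filtering, the queue may grow to 2x the base limit.
print(effective_buffer_limit(16, {"sample_utilization": 0.5}))  # 32.0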

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 0 additions & 1 deletion

@@ -346,7 +346,6 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 5 additions & 1 deletion

@@ -59,6 +59,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if sgl is None:
             raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ def __init__(
         generate_config: Dict[str, Any],
         tokenizer: PreTrainedTokenizer,
         num_generations: int = 8,
+        tokenizer_config: Dict[str, Any] = None,
     ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(model=path, **model_config)
+        tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
+        self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})

applications/ColossalChat/coati/distributed/launch_zero_bubble.py

Lines changed: 2 additions & 3 deletions

@@ -130,7 +130,7 @@ def launch_distributed(
         train_dataset_config=train_dataset_config,
         model_config=inference_model_config,
         generate_config=generate_config,
-        tokenizer_config=tokenizer_config,
+        tokenizer_config=copy.deepcopy(tokenizer_config),
         microbatch_size=inference_microbatch_size,
         backend=inference_backend,
         num_generations=num_generations,
@@ -158,8 +158,6 @@ def launch_distributed(
     consumer_master_ip_address = gpu_to_ip_address[0]
     print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
     consumer_procs = []
-    if num_consumer_procs <= 1:
-        raise ValueError("Number of consumer processes should be greater than 1 for async rl training.")
     for i in range(num_consumer_procs):
         node_id = gpu_to_node_id[0]
         consumer_ip_address = gpu_to_ip_address[0]
@@ -180,6 +178,7 @@ def launch_distributed(
         model_config=train_model_config,
         plugin_config=plugin_config,
         minibatch_size=train_minibatch_size,
+        tokenizer_config=copy.deepcopy(tokenizer_config),
         generate_config=generate_config_consumer,
         grpo_config=grpo_config,
         num_generations=num_generations,
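Passing copy.deepcopy(tokenizer_config) to the producer and the consumer matters because the consumer later pops the path key out of its tokenizer_config (see grpo_consumer.py below); with a single shared dict, the first pop would silently strip the path for every other actor. A small sketch of that failure mode, with illustrative values:

import copy

tokenizer_config = {"path": "/models/tokenizer", "trust_remote_code": True}

# Without a deepcopy, both actors see the same dict:
shared = tokenizer_config
shared.pop("path")                 # first consumer consumes the path...
print(tokenizer_config)            # {'trust_remote_code': True} -> next actor has no path

# With a deepcopy, each actor mutates only its own copy:
tokenizer_config = {"path": "/models/tokenizer", "trust_remote_code": True}
own = copy.deepcopy(tokenizer_config)
own.pop("path")
print(tokenizer_config)            # the path is still there for the next actor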

applications/ColossalChat/coati/distributed/loss.py

Lines changed: 2 additions & 2 deletions

@@ -37,9 +37,9 @@ def forward(
         total_effective_tokens_in_batch: torch.Tensor = None,
     ) -> torch.Tensor:
         if action_mask is None:
-            ratio = (log_probs - log_probs.detach()).exp()
+            ratio = (log_probs - old_log_probs.detach()).exp()
         else:
-            ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+            ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()

         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
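This is the substantive part of the fix: (log_probs - log_probs.detach()).exp() is identically 1, so the importance ratio never deviated from 1 and the clipping term never triggered; the ratio must compare the current policy against the rollout-time (old) policy. A self-contained sketch of the clipped surrogate on dummy tensors (the clip epsilons are illustrative):

import torch

log_probs = torch.tensor([-1.0, -0.5, -2.0])       # current policy log-probs per token
old_log_probs = torch.tensor([-1.2, -0.4, -1.0])   # log-probs recorded at rollout time
advantages = torch.tensor([1.0, -0.5, 2.0])
clip_eps_low, clip_eps_high = 0.2, 0.2             # illustrative clip range

ratio = (log_probs - old_log_probs.detach()).exp() # pi_theta / pi_old per token
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
loss = -torch.min(surr1, surr2).mean()
print(ratio, loss)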

applications/ColossalChat/coati/distributed/zero_bubble/consumer.py

Lines changed: 4 additions & 4 deletions

@@ -7,17 +7,16 @@
 import ray.util.collective as cc
 import torch
 import torch.distributed as dist
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler
+from coati.distributed.utils import bind_batch, post_recv, unbind_batch
 from tqdm import tqdm

 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.utils import get_current_device

-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-from coati.distributed.utils import bind_batch, post_recv, unbind_batch
-

 class BaseConsumer:
     def __init__(
@@ -175,14 +174,15 @@ def loop(self) -> None:
                 raw_batch = ray.get(
                     self.shared_sync_data_actor.get_data.remote(self.data_uid)
                 )  # get the first queued data
+                self.profiler.log(f"enter sleep")
                 while raw_batch is None:
-                    self.profiler.log(f"No data received by consumer {self.rank}, skipping")
                     print(
                         f"[T{dist.get_rank()}] No data received by consumer {self.rank}, skipping. Consider increasing the data actor buffer limit"
                     )
                     time.sleep(1)
                     raw_batch = ray.get(self.shared_sync_data_actor.get_data.remote(self.data_uid))
                     continue
+                self.profiler.log(f"exit sleep")
                 self.data_uid += 1
                 raw_batch = {k: v.to(self.device) for k, v in raw_batch.items()}
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
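The consumer polls the shared data actor until a batch with its data_uid shows up; the new profiler calls bracket the whole wait instead of logging once per retry. A stripped-down sketch of the polling pattern (the actor and method names follow the diff, the rest is illustrative):

import time
import ray

@ray.remote
class SharedDataActor:
    """Illustrative stand-in for the shared data actor used in the diff."""
    def __init__(self):
        self.queue = {}
    def put_data(self, uid, batch):
        self.queue[uid] = batch
    def get_data(self, uid):
        return self.queue.pop(uid, None)   # None means the batch is not produced yet

def wait_for_batch(actor, uid, poll_interval=1.0):
    batch = ray.get(actor.get_data.remote(uid))
    while batch is None:                   # data may still be generating
        time.sleep(poll_interval)
        batch = ray.get(actor.get_data.remote(uid))
    return batch

# Usage (requires ray.init()):
# actor = SharedDataActor.remote()
# actor.put_data.remote(0, {"input_ids": ...})
# batch = wait_for_batch(actor, uid=0)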

applications/ColossalChat/coati/distributed/zero_bubble/distributor.py

Lines changed: 16 additions & 5 deletions

@@ -3,12 +3,11 @@
 import ray
 import ray.util.collective as cc
 import torch
+from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
 from coati.distributed.profiling_utils import CustomProfiler

 from colossalai.utils import get_current_device

-from coati.distributed.comm import SharedVariableActor, ray_broadcast_tensor_dict
-

 @ray.remote
 class Distributor:
@@ -21,6 +20,7 @@ def __init__(
         enable_profiling: bool = True,
     ):
         self.distributor_id = distributor_id
+        self.weight_version = [0] * consumer_pp_size
         self.consumer_pp_size = consumer_pp_size
         self.state_dict_cpu = {}
         self.num_producers = num_producers
@@ -42,14 +42,17 @@ def init_collective_group(
         print(f"[D] Initialized {group_name} collective group", flush=True)

     def loop(self):
+        last_weight_version = self.get_weight_version()
         while True:
             time.sleep(1)
             signal = ray.get(self.shared_signal_actor.get_signal.remote())
             if self.consumer_pp_size > 1:
-                for i in range(self.consumer_pp_size):
-                    if signal.get(f"consumer_pp_{i}", None) == "ready_sync_model":
+                if all(
+                    [signal.get(f"consumer_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
+                ):
+                    cc.barrier(group_name="distributor_pg")
+                    for i in range(self.consumer_pp_size):
                         self.profiler.enter(f"sync_model_consumer_pp_{i}")
-                        cc.barrier(group_name="distributor_pg")
                         ray.get(self.shared_signal_actor.set_signal.remote(f"consumer_pp_{i}", "not_ready_sync_model"))
                         # Broadcast the model state dict from consumer to shared variable actor
                         self.state_dict_cpu[i] = ray_broadcast_tensor_dict(
@@ -60,6 +63,7 @@ def loop(self):
                             backend="gloo",
                         )
                         self.profiler.exit(f"sync_model_consumer_pp_{i}")
+                        self.weight_version[i] += 1
                 for i in range(self.consumer_pp_size):
                     if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
                         self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
@@ -87,6 +91,7 @@ def loop(self):
                         None, 0, device=torch.device("cpu"), group_name="sync_model_consumer", backend="gloo"
                     )
                     self.profiler.exit("sync_model_consumer")
+                    self.weight_version[0] += 1
                 if signal.get(f"producer_{self.distributor_id}", None) == "ready_sync_model":
                     self.profiler.enter(f"sync_model_producer_{self.distributor_id}")
                     # Broadcast the model state dict to all producers
@@ -106,3 +111,9 @@ def loop(self):
             if signal.get("consumer", None) == "terminate":
                 self.profiler.log("terminate sync model worker")
                 break
+            if last_weight_version != self.get_weight_version():
+                last_weight_version = self.get_weight_version()
+                ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
+
+    def get_weight_version(self):
+        return min(self.weight_version)
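The distributor change addresses the race named in the commit title: previously each pipeline stage was handled as soon as its own ready_sync_model flag appeared, so the collective barrier could be entered while other stages were still training. Now the barrier is crossed only when every stage is ready, and the distributor publishes a weight version, taken as the minimum across stages so it only advances once a complete new weight set has been synced. A compact sketch of the versioning rule (the Ray actor plumbing is omitted; the tracker class here is illustrative):

class WeightVersionTracker:
    """Tracks per-pipeline-stage sync counts; the published version is their minimum."""
    def __init__(self, consumer_pp_size: int):
        self.weight_version = [0] * consumer_pp_size

    def on_stage_synced(self, stage: int) -> None:
        self.weight_version[stage] += 1

    def get_weight_version(self) -> int:
        # A full new weight set exists only once every stage has broadcast its shard.
        return min(self.weight_version)

tracker = WeightVersionTracker(consumer_pp_size=2)
tracker.on_stage_synced(0)
print(tracker.get_weight_version())  # 0 -> stage 1 has not synced yet
tracker.on_stage_synced(1)
print(tracker.get_weight_version())  # 1 -> a complete new version is available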

applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py

Lines changed: 43 additions & 6 deletions

@@ -5,9 +5,9 @@
 import torch
 import wandb
 from coati.distributed.comm import SharedVariableActor
-from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
+from coati.distributed.zero_bubble.consumer import BaseConsumer
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -33,6 +33,7 @@ def __init__(
         plugin_config,
         minibatch_size=1,
         num_generations=8,
+        tokenizer_config=None,
         generate_config=None,
         grpo_config={},
         save_interval: int = 100,
@@ -73,9 +74,11 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
+        self.vocab_size = self.policy_model.config.vocab_size
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -102,8 +105,11 @@ def __init__(
         if self.policy_loss_fn.beta > 0:
             self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
             self.reference_model.eval()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        if tokenizer_config is not None:
+            path = tokenizer_config.pop("path", None)
+            self.tokenizer = AutoTokenizer.from_pretrained(path, **tokenizer_config)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
         self.filter_range = grpo_config.get("filter_range", None)
@@ -243,10 +249,14 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
+            mini_batch_entropies = []
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
@@ -303,6 +313,7 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
@@ -312,6 +323,12 @@ def step(self, pbar: Any, **kwargs) -> Optional[float]:

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -334,7 +351,7 @@ def _criterion(outputs, inputs):

                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
@@ -396,7 +413,7 @@ def _criterion(outputs, inputs):

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -411,6 +428,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1
                 and self.booster.plugin.stage_manager.is_last_stage()
@@ -422,7 +453,9 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)
@@ -465,6 +498,7 @@ def _criterion(outputs, inputs):
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                     f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                    f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -476,15 +510,18 @@ def _criterion(outputs, inputs):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/entropy": self.accum_entropy.item() / self.accum_count,
                     "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
                 if self.wandb_run is not None:
                     self.wandb_run.log(metrics)
+                ray.get(self.shared_signal_actor.set_signal.remote("sample_utilization", sample_utilization))
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
                 return loss_scalar
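The new entropy metric averages per-token entropy over the response (action) tokens only, accumulates it alongside loss and KL, and reports it in the logs and wandb metrics; the consumer also feeds sample_utilization back to the shared signal actor, which closes the loop with the adaptive buffer limit in comm.py above. A hedged sketch of the masked entropy computation, assuming entropy_from_logits follows the standard softmax-entropy formula (the helper's actual implementation lives in coati.distributed.utils and is not shown in this diff):

import torch
import torch.nn.functional as F

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Categorical entropy per token position: H = -sum_v p_v * log p_v.
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(-1)

# Dummy shapes: (batch, num_action_tokens, vocab_size)
action_logits = torch.randn(2, 5, 32)
action_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)

# Mean entropy over valid response tokens, one value per sequence (matches the diff's reduction).
per_seq_entropy = (entropy_from_logits(action_logits) * action_mask).sum(-1) / action_mask.sum(-1)
print(per_seq_entropy)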
