Skip to content

Commit 50070c1

Browse files
committed
move logging to producer
1 parent 47a7dc7 commit 50070c1

File tree

7 files changed

+92
-70
lines changed

7 files changed

+92
-70
lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def __init__(
3636
minibatch_size: int = 1,
3737
save_interval: int = 100,
3838
save_dir: str = "./model",
39-
eval_interval: int = -1,
4039
):
4140
self.num_producers = num_producers
4241
self.num_episodes = num_episodes
@@ -52,7 +51,6 @@ def __init__(
5251
self.save_dir = save_dir
5352
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
5453
self.num_microbatches = batch_size // minibatch_size
55-
self.eval_interval = eval_interval
5654

5755
self.model_config = model_config
5856
self.plugin_config = plugin_config
@@ -94,9 +92,6 @@ def setup(self) -> None:
9492
if self.rank == 0:
9593
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
9694

97-
for i in range(self.num_producers):
98-
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
99-
10095
self.buffer = []
10196
self.recv_cnt = 0
10297

@@ -114,24 +109,6 @@ def loop(self) -> None:
114109
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
115110
for step in pbar:
116111
i = 0
117-
if self.eval_interval > 0 and step % self.eval_interval == 0:
118-
eval_statistics = None
119-
for r in range(self.num_producers):
120-
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
121-
local_eval_result = ray_broadcast_tensor_dict(
122-
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
123-
)
124-
if eval_statistics is None:
125-
eval_statistics = local_eval_result
126-
else:
127-
eval_statistics = {
128-
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
129-
}
130-
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
131-
if dist.get_rank() == 0:
132-
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
133-
self.wandb_run.log(eval_statistics, step=self.global_step)
134-
print(f"Eval statistics: {eval_statistics}")
135112
for _ in range(self.num_recv_per_update):
136113
# receive data from producers
137114
for r in range(self.num_producers):
@@ -214,7 +191,6 @@ def __init__(
214191
minibatch_size=1,
215192
save_interval: int = 100,
216193
save_dir="./model",
217-
eval_interval: int = -1,
218194
):
219195
super().__init__(
220196
num_producers,
@@ -231,7 +207,6 @@ def __init__(
231207
minibatch_size,
232208
save_interval,
233209
save_dir,
234-
eval_interval,
235210
)
236211
path = model_config.pop("path")
237212
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ def __init__(
3434
plugin_config,
3535
minibatch_size=1,
3636
num_generations=8,
37-
use_wandb=True,
3837
generate_config=None,
3938
grpo_config={},
40-
project_name=None,
4139
save_interval: int = 100,
4240
save_dir="./model",
43-
eval_interval: int = -1,
41+
project_name: str = None,
42+
run_name: str = None,
43+
wandb_group_name: str = None,
4444
):
4545
print(f"Using GRPO config: {grpo_config}")
4646
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +73,6 @@ def __init__(
7373
minibatch_size,
7474
save_interval=save_interval,
7575
save_dir=save_dir,
76-
eval_interval=eval_interval,
7776
)
7877
path = model_config.pop("path")
7978
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -93,6 +92,9 @@ def __init__(
9392
self.project_name = project_name
9493
self.effective_sample_count = 0
9594
self.total_sample_count = 0
95+
self.project_name = project_name
96+
self.run_name = run_name
97+
self.wandb_group_name = wandb_group_name
9698

9799
self.policy_loss_fn = PolicyLoss(
98100
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -143,7 +145,6 @@ def __init__(
143145
**reward_model_kwargs,
144146
)
145147
self.global_step = 0
146-
self.use_wandb = use_wandb
147148

148149
self.lr_scheduler = CosineAnnealingWarmupLR(
149150
optimizer=self.optimizer,
@@ -154,13 +155,16 @@ def __init__(
154155

155156
def setup(self):
156157
super().setup()
157-
if self.use_wandb and (
158-
(not self.plugin.pp_size > 1 and self.rank == 0)
159-
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
158+
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
159+
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
160160
):
161-
# Initialize wandb.
162-
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
163-
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
161+
self.wandb_run = wandb.init(
162+
project=self.project_name,
163+
sync_tensorboard=True,
164+
dir="./wandb",
165+
name=self.run_name,
166+
group=self.wandb_group_name,
167+
)
164168

165169
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
166170
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -512,8 +516,8 @@ def _criterion(outputs, inputs):
512516
}
513517
if self.policy_loss_fn.beta > 0:
514518
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
515-
516-
self.wandb_run.log(metrics)
519+
if self.wandb_run is not None:
520+
self.wandb_run.log(metrics)
517521
self.accum_loss.zero_()
518522
self.accum_reward.zero_()
519523
self.accum_ans_acc.zero_()

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
238238
log_probs.append(p)
239239

240240
# pad them
241-
max_len = self.generate_config.max_tokens
241+
max_len = self.sample_params.max_tokens
242242
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
243243

244244
for i, new_token_ids in enumerate(out_tokens):

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import uuid
23
from typing import Any, Dict, Optional
34

45
import ray
@@ -53,6 +54,7 @@ def launch_distributed(
5354
eval_dataset_config: Optional[Dict[str, Any]] = None,
5455
eval_interval: int = 100,
5556
eval_save_dir: Optional[str] = None,
57+
eval_generation_config: Optional[Dict[str, Any]] = None,
5658
):
5759

5860
if core_algo not in ALGO_MAP:
@@ -69,6 +71,9 @@ def launch_distributed(
6971
num_update_per_episode = num_samples // global_inference_batch_size
7072
num_recv_per_update = inference_batch_size // inference_microbatch_size
7173

74+
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
75+
wandb_group_name = str(uuid.uuid4())
76+
7277
procs = []
7378
for i in range(num_producers):
7479
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -90,6 +95,10 @@ def launch_distributed(
9095
eval_interval=eval_interval,
9196
evaluation_function_type=grpo_config["reward_fn_type"],
9297
eval_save_dir=eval_save_dir,
98+
eval_generation_config=eval_generation_config,
99+
project_name=project_name,
100+
run_name=run_name,
101+
wandb_group_name=wandb_group_name,
93102
)
94103
procs.append(producer)
95104
generate_config_consumer = copy.deepcopy(generate_config)
@@ -115,10 +124,11 @@ def launch_distributed(
115124
generate_config=generate_config_consumer,
116125
grpo_config=grpo_config,
117126
num_generations=num_generations,
118-
project_name=project_name,
119127
save_interval=save_interval,
120128
save_dir=save_dir,
121-
eval_interval=eval_interval,
129+
project_name=project_name,
130+
run_name=run_name,
131+
wandb_group_name=wandb_group_name,
122132
)
123133
procs.append(consumer)
124134
ray.get([p.setup.remote() for p in procs])

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66
import ray.util.collective as cc
77
import torch
88
import tqdm
9+
import wandb
910
from coati.dataset.loader import RawConversationDataset
1011
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
12+
from ray.util.collective import allreduce
13+
from ray.util.collective.types import Backend, ReduceOp
1114
from torch.utils.data import DataLoader, DistributedSampler
1215
from transformers import AutoTokenizer
1316

1417
from colossalai.utils import get_current_device
1518

1619
from .comm import ray_broadcast_tensor_dict
1720
from .inference_backend import BACKEND_MAP
18-
from .utils import pre_send, safe_write_jsonl
21+
from .utils import pre_send, safe_append_to_jsonl_file
1922

2023
try:
2124
from vllm import SamplingParams
@@ -43,6 +46,9 @@ def __init__(
4346
eval_interval=-1, # disable evaluation
4447
evaluation_function_type="think_answer_tags",
4548
eval_save_dir: str = "./eval",
49+
project_name: str = None,
50+
run_name: str = None,
51+
wandb_group_name: str = None,
4652
):
4753
self.producer_idx = producer_idx
4854
self.num_producers = num_producers
@@ -61,6 +67,14 @@ def __init__(
6167
self.eval_interval = eval_interval
6268
self.eval_save_dir = eval_save_dir
6369
self.consumer_global_step = 0
70+
if self.producer_idx == 0:
71+
self.wandb_run = wandb.init(
72+
project=project_name,
73+
sync_tensorboard=True,
74+
dir="./wandb",
75+
name=run_name + "_eval",
76+
group=wandb_group_name,
77+
)
6478

6579
if os.path.exists(self.eval_save_dir):
6680
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
@@ -132,13 +146,18 @@ def __init__(
132146
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
133147

134148
def setup(self) -> None:
149+
cc.init_collective_group(
150+
world_size=self.num_producers,
151+
rank=self.producer_idx,
152+
backend=Backend.NCCL,
153+
group_name="producer_group",
154+
)
135155
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
136156
if self.consumer_pp_size > 1:
137157
for i in range(self.consumer_pp_size):
138158
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
139159
else:
140160
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
141-
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
142161

143162
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
144163
raise NotImplementedError
@@ -160,13 +179,14 @@ def loop(self) -> None:
160179
break
161180
if self.eval_interval > 0 and self.eval_dataset_config is not None:
162181
if i % self.eval_interval == 0:
163-
eval_statistics = {}
182+
to_log_msg = {}
164183
for eval_task_name in self.eval_dataloaders:
165-
print(
166-
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
167-
)
184+
if self.producer_idx == 0:
185+
print(
186+
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
187+
)
168188
eval_results = []
169-
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
189+
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
170190
for eval_batch in tqdm.tqdm(
171191
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
172192
):
@@ -182,24 +202,27 @@ def loop(self) -> None:
182202
for m in range(eval_outputs["input_ids"].size(0))
183203
for n in range(eval_outputs["input_ids"].size(1))
184204
]
185-
eval_statistics[eval_task_name][0] += len(
186-
[res for res in eval_results if res["ans_valid"] == 1]
205+
eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
206+
eval_statistics_tensor[1] += len(eval_results)
207+
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
208+
to_log_msg[f"eval/{eval_task_name}"] = (
209+
eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
187210
)
188-
eval_statistics[eval_task_name][1] += len(eval_results)
211+
if self.producer_idx == 0:
212+
print(
213+
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
214+
)
189215
# save eval results
190-
result_file_name = os.path.join(
191-
self.eval_save_dir,
192-
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
216+
safe_append_to_jsonl_file(
217+
os.path.join(
218+
self.eval_save_dir,
219+
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
220+
),
221+
eval_results,
193222
)
194-
# delete the file if it exists
195-
safe_write_jsonl(result_file_name, eval_results)
196-
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
197-
ray_broadcast_tensor_dict(
198-
eval_statistics,
199-
src=0,
200-
device=self.device,
201-
group_name=f"sync_eval_statistics_{self.producer_idx}",
202-
)
223+
224+
if self.producer_idx == 0:
225+
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
203226
outputs = self.rollout(**batch)
204227

205228
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -248,12 +271,11 @@ def loop(self) -> None:
248271
# linear annealing for 1 episode, temperature from initial to 0.9
249272
if episode <= 0:
250273
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
251-
if isinstance(self.model.generate_config.temperature, dict):
252-
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
253-
"temperature"
254-
] + ratio * 0.9
255-
else:
256-
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
274+
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
275+
"temperature"
276+
] + ratio * 0.9
277+
if isinstance(self.model, BACKEND_MAP["vllm"]):
278+
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
257279
"temperature"
258280
] + ratio * 0.9
259281

@@ -280,6 +302,10 @@ def __init__(
280302
eval_interval=-1, # disable evaluation
281303
evaluation_function_type="think_answer_tags",
282304
eval_save_dir: str = "./eval",
305+
eval_generation_config={},
306+
project_name: str = None,
307+
run_name: str = None,
308+
wandb_group_name: str = None,
283309
):
284310
super().__init__(
285311
producer_idx,
@@ -299,10 +325,14 @@ def __init__(
299325
eval_interval=eval_interval,
300326
evaluation_function_type=evaluation_function_type,
301327
eval_save_dir=eval_save_dir,
328+
project_name=project_name,
329+
run_name=run_name,
330+
wandb_group_name=wandb_group_name,
302331
)
303332
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
304333
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
305334
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
335+
self.eval_generation_config.update(eval_generation_config)
306336
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
307337

308338
@torch.no_grad()

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
135135
return tensor.sum(dim=dim)
136136

137137

138-
def safe_write_jsonl(file_path, data):
138+
def safe_append_to_jsonl_file(file_path, data):
139139
with FileLock(file_path + ".lock"):
140140
# Ensure file exists
141141
os.makedirs(os.path.dirname(file_path), exist_ok=True)

0 commit comments

Comments (0)