
Commit 116621d

merge reward and eval
2 parents: 11a5854 + 107470a

File tree

7 files changed: +221 -71 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 22 additions & 26 deletions
@@ -113,7 +113,6 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -140,7 +139,6 @@ def loop(self) -> None:
                         loss = self.step(i, pbar, **batch)
                         self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                         if loss is not None:
-                            allow_sync_model = True
                             pbar.set_postfix({"loss": loss})
                         i += 1
                     if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ def loop(self) -> None:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
+                                )
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                            allow_sync_model = True
+                        del state_dict
+                        torch.cuda.empty_cache()


 @ray.remote
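
Note on the allow_sync_model removal: in the lines visible here the flag is initialized to True and only ever reassigned True, so it never actually suppressed a sync; deleting it simplifies loop() without changing behavior, and the consumer now plainly broadcasts updated weights after every update step except the final one. The rank gating that remains is sketched below as a standalone function (an illustration of the visible logic only, not the repo's communication layer; ray_broadcast_tensor_dict does the real work):

    from typing import Optional

    # Toy sketch of the gating kept in loop(): with pipeline parallelism, only
    # rank (tp=0, dp=0) of each PP stage broadcasts, each stage on its own
    # named group; without PP, only global rank 0 broadcasts.
    def sync_group_for(pp_size: int, pp_rank: int, tp_rank: int, dp_rank: int, rank: int) -> Optional[str]:
        if pp_size > 1:
            if tp_rank == 0 and dp_rank == 0:
                return f"sync_model_{pp_rank}"
            return None
        return "sync_model" if rank == 0 else None

    print(sync_group_for(pp_size=2, pp_rank=1, tp_rank=0, dp_rank=0, rank=4))  # sync_model_1
    print(sync_group_for(pp_size=1, pp_rank=0, tp_rank=0, dp_rank=0, rank=0))  # sync_model
    print(sync_group_for(pp_size=1, pp_rank=0, tp_rank=1, dp_rank=0, rank=3))  # None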

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 7 deletions
@@ -120,14 +120,20 @@ def __init__(
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
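
With this change response_format_tags is only materialized for the "think_answer_tags" reward type and is None otherwise. For intuition, a hypothetical tag-count checker (an illustration only, not the repo's VerifiableReward) shows how such a tag spec can back a binary format reward:

    # Same tag spec as in the diff above.
    response_format_tags = {
        "think_start": {"text": "<think>", "num_occur": 1},
        "think_end": {"text": "</think>", "num_occur": 1},
        "answer_start": {"text": "<answer>", "num_occur": 1},
        "answer_end": {"text": "</answer>", "num_occur": 1},
    }

    def format_reward(response: str, tags=response_format_tags) -> float:
        # 1.0 if every tag occurs exactly num_occur times, else 0.0.
        ok = all(response.count(t["text"]) == t["num_occur"] for t in tags.values())
        return 1.0 if ok else 0.0

    print(format_reward("<think>2 + 2 = 4</think><answer>4</answer>"))  # 1.0
    print(format_reward("4"))                                           # 0.0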

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_log_file: str = "./rollout_log.jsonl",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
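
The new arguments are additive and defaulted, so existing call sites keep working. A hypothetical call site (pre-existing arguments elided on purpose; the rollouts/ directory is an arbitrary choice that lines up with the new .gitignore entry):

    launch_distributed(
        ...,  # existing arguments unchanged
        log_rollout_interval=20,                          # log once every >= 20 consumer steps
        rollout_log_file="./rollouts/rollout_log.jsonl",  # must not already exist
    )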

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 55 additions & 11 deletions
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional

@@ -49,6 +50,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -58,7 +61,7 @@ def __init__(
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +71,17 @@ def __init__(
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.log_rollout_interval = log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -77,7 +91,7 @@ def __init__(
                 group=wandb_group_name,
             )

-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
@@ -180,14 +194,15 @@ def loop(self) -> None:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step % self.eval_interval == 0
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                         to_log_msg = {}
+                        self.eval_mode = True
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate consumer step {self.consumer_global_step} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -220,14 +235,15 @@ def loop(self) -> None:
                             safe_append_to_jsonl_file(
                                 os.path.join(
                                     self.eval_save_dir,
-                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                 ),
                                 eval_results,
                             )

                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                        self.lastest_eval_step = self.consumer_global_step
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
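
The eval trigger switches from a modulo test to a difference test. consumer_global_step is polled from the consumer and can advance by more than one between polls, so step % eval_interval == 0 can skip a multiple entirely; the difference form fires whenever at least eval_interval steps have elapsed, and latest_eval_step == -1 additionally forces one evaluation at startup. A toy trace over a hypothetical poll sequence:

    eval_interval = 100
    observed_steps = [0, 37, 99, 101, 180, 205, 399, 402]  # hypothetical polls

    latest = -1
    for s in observed_steps:
        modulo_fire = s % eval_interval == 0 and s > latest
        diff_fire = (s - latest >= eval_interval and s > latest) or latest == -1
        print(f"step {s:3d}: modulo={modulo_fire} difference={diff_fire}")
        if diff_fire:
            latest = s  # only the new rule updates the watermark here

    # The modulo rule never fires at 101, 205, or 399: none is a multiple of
    # 100, even though >= 100 steps elapsed since the previous evaluation.
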
@@ -256,6 +272,8 @@ def loop(self) -> None:
                         state_dict = ray_broadcast_tensor_dict(
                             None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                         )
+                        if "consumer_global_step" in state_dict:
+                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                         self.load_state_dict(state_dict)
                 else:
                     print(
@@ -311,6 +329,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -333,6 +353,8 @@ def __init__(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -343,10 +365,32 @@
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+        if self.producer_idx == 0 and not self.eval_mode:
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts

+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
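
Producer 0 now appends one JSON object per logged rollout and flushes immediately, so the file can be tailed during training. A minimal sketch for reading it back offline (path is the default from this diff; keys match the record written in rollout()):

    import json

    # Each line: {"train_step": <int>, "rollout": [<one decoded sample per prompt>, ...]}
    with open("./rollout_log.jsonl", encoding="utf8") as f:
        for line in f:
            record = json.loads(line)
            sample = record["rollout"][0] if record["rollout"] else ""
            print(f"step {record['train_step']}: {sample[:120]!r}")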
