
Commit 107470a

fix logging rollouts

1 parent 03b41d6 · commit 107470a

5 files changed: +56 −24 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 10 additions & 6 deletions
@@ -120,12 +120,16 @@ def __init__(
             "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
         )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
applications/ColossalChat/coati/distributed/launch.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_log_file: str = "./rollout_log.jsonl",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -98,6 +100,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 36 additions & 18 deletions
@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional
 
@@ -49,7 +50,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
-        wandb_log_rollout_interval: int = 20,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -70,9 +72,16 @@ def __init__(
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
         self.eval_mode = False
-        self.wandb_rollout_data = []
-        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -320,6 +329,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -342,6 +353,8 @@ def __init__(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@ def __init__(
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 0 and not self.eval_mode:
-            wandb_rollout_data = self.wandb_rollout_data + [
-                [
-                    str(self.consumer_global_step),
-                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
-                ]
-            ]
             if (
-                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                 or self.latest_rollout_log_step == -1
             ):
-                self.wandb_rollout_data = wandb_rollout_data
-                self.latest_rollout_log_step = self.consumer_global_step
-                self.wandb_run.log(
-                    {
-                        "rollout/rollout_examples": wandb.Table(
-                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
                     )
-                }
-            )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts
 
+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
applications/ColossalChat/rl_example.py

Lines changed: 5 additions & 0 deletions
@@ -118,6 +118,9 @@
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:
@@ -269,4 +272,6 @@
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
+        log_rollout_interval=20,
+        rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
     )
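
For reference, a small sketch of how the rollout log path is derived from the new flag (the project name below is hypothetical): spaces in --project become underscores and a ".jsonl" suffix is appended, mirroring the os.path.join call above.

import os

rollout_save_dir = "./rollouts"   # default of -rsd / --rollout-save-dir
project = "my grpo run"           # hypothetical --project value

rollout_log_file = os.path.join(rollout_save_dir, project.replace(" ", "_") + ".jsonl")
print(rollout_log_file)  # ./rollouts/my_grpo_run.jsonl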
