
Commit 203dfb1

address conversation
1 parent 6abffb9 commit 203dfb1


3 files changed: 10 additions & 4 deletions


applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
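The change above only removes the TODO comment from the line; the computation itself is unchanged. For reference, a minimal sketch of the batch-size arithmetic around this hunk (variable names come from the diff; the numeric values are invented purely for illustration):

    # Hypothetical values, for illustration only.
    inference_batch_size = 8           # samples one producer generates per step
    num_producers = 4                  # number of producer processes
    num_samples = 1024                 # rows in the training JSONL dataset
    inference_microbatch_size = 2      # micro-batch size for sending rollouts

    # All producers together contribute one global inference batch.
    global_inference_batch_size = inference_batch_size * num_producers       # 32
    # Number of update rounds in one pass over the dataset.
    num_update_per_episode = num_samples // global_inference_batch_size      # 32
    # Micro-batches received per update.
    num_recv_per_update = inference_batch_size // inference_microbatch_size  # 8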

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 2 additions & 2 deletions
@@ -187,7 +187,7 @@ def loop(self) -> None:
     for eval_task_name in self.eval_dataloaders:
         if self.producer_idx == 0:
             print(
-                f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
             )
         eval_results = []
         eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)

@@ -220,7 +220,7 @@ def loop(self) -> None:
         safe_append_to_jsonl_file(
             os.path.join(
                 self.eval_save_dir,
-                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
             ),
             eval_results,
         )
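Both hunks rename the evaluation log message and the result file from an (episode, step) key to a training-step key. A minimal sketch of the resulting file naming (directory, task name, and step value are illustrative, not taken from a real run):

    import os

    # Hypothetical values; in the producer these come from self.* attributes.
    eval_save_dir = "./eval_results"
    eval_task_name = "math_eval"
    consumer_global_step = 250

    # Evaluation results are now keyed by training step only.
    out_path = os.path.join(
        eval_save_dir,
        f"{eval_task_name}_training_step_{consumer_global_step}.jsonl",
    )
    print(out_path)  # ./eval_results/math_eval_training_step_250.jsonl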

applications/ColossalChat/rl_example.py

Lines changed: 7 additions & 1 deletion
@@ -104,7 +104,13 @@
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei training steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
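The flag and its default are unchanged; the definition is only reflowed and its help text clarified to state that evaluation runs every ei training steps. A small usage sketch of parsing the flag (the surrounding parser setup is assumed, not shown in the diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-ei",
        "--eval-interval",
        type=int,
        default=100,
        help="Interval for evaluation. Evaluate every ei training steps.",
    )

    args = parser.parse_args(["--eval-interval", "50"])
    print(args.eval_interval)  # 50 -> evaluate every 50 training steps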
