Skip to content

Commit 01640eb

Browse files
committed
fix bug
1 parent bd61918 commit 01640eb

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ def launch_distributed(
8787
num_generations=num_generations,
8888
consumer_plugin_config=plugin_config,
8989
eval_dataset_config=eval_dataset_config,
90-
eval_interval=eval_interval,
90+
eval_interval=eval_interval * num_recv_per_update,
9191
evaluation_function_type=grpo_config["reward_fn_type"],
9292
eval_save_dir=eval_save_dir,
9393
)

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def __init__(
129129
else:
130130
raise ValueError(f"Unexpected backend {backend}")
131131

132-
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
132+
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
133133

134134
def setup(self) -> None:
135135
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
@@ -250,14 +250,11 @@ def loop(self) -> None:
250250
# linear annealing for 1 episode, temperature from initial to 0.9
251251
if episode <= 0:
252252
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
253-
if isinstance(self.model.generate_config.temperature, dict):
254-
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
255-
"temperature"
256-
] + ratio * 0.9
257-
else:
258-
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
259-
"temperature"
260-
] + ratio * 0.9
253+
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
254+
"temperature"
255+
] + ratio * 0.9
256+
if hasattr(self.model, "sample_params"):
257+
self.model.sample_params.temperature = self.model.generate_config["temperature"]
261258

262259

263260
@ray.remote
@@ -310,8 +307,8 @@ def __init__(
310307
@torch.no_grad()
311308
def rollout(self, input_ids, attention_mask, **kwargs):
312309
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
313-
# if self.producer_idx == 1:
314-
# print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
310+
if self.producer_idx == 1:
311+
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
315312

316313
return rollouts
317314

0 commit comments

Comments (0)