
Commit 6b06430

fix small bug

1 parent e3d56cb

1 file changed, 3 insertions(+), 3 deletions(-)

applications/ColossalChat/coati/distributed/grpo_consumer.py (3 additions, 3 deletions)
@@ -281,7 +281,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
         if self.booster.plugin.stage_manager.is_last_stage():
             reference_action_log_probs = memory_efficient_logprob(
-                reference_model_outputs["outputs"]["logits"],
+                reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                 input_ids_forward_micro_batch,
                 num_action,
                 shard_config=self.plugin.shard_config,
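
A hedged note on the temperature changes, not part of the commit itself: during rollout the tokens are sampled from softmax(logits / temperature), so the reference and policy log-probabilities fed into the GRPO loss should be computed under that same scaling; scoring the sampled tokens against unscaled logits makes the ratio and KL terms inconsistent with the sampling distribution. The sketch below only illustrates the idea; gather_action_log_probs is a hypothetical stand-in, not the real memory_efficient_logprob, whose signature is only partially visible in this diff.

import torch
import torch.nn.functional as F

def gather_action_log_probs(logits, input_ids, num_action, temperature=1.0):
    # logits:    (batch, seq_len, vocab) raw model outputs
    # input_ids: (batch, seq_len) tokens that were actually sampled
    # num_action: number of trailing (generated) tokens to score
    # Match the sampling distribution: rollouts were drawn from
    # softmax(logits / temperature), so score tokens under the same scaling.
    scaled = logits / temperature
    # Logits at position t predict the token at position t + 1.
    log_probs = F.log_softmax(scaled[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)
    token_log_probs = log_probs.gather(dim=-1, index=targets).squeeze(-1)
    # Keep only the generated (action) portion of the sequence.
    return token_log_probs[:, -num_action:]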
@@ -308,7 +308,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         def _criterion(outputs, inputs):
             action_logits = outputs.logits
             action_log_probs = memory_efficient_logprob(
-                action_logits,
+                action_logits / self.generate_config["temperature"],
                 inputs["input_ids"],
                 num_action,
                 shard_config=self.plugin.shard_config,
@@ -375,7 +375,7 @@ def _criterion(outputs, inputs):
             reference_model_logits / self.generate_config["temperature"],
             input_ids_forward_micro_batch,
             num_action,
-            self.plugin.shard_config,
+            shard_config=self.plugin.shard_config,
         )
         per_token_kl = (
             torch.exp(reference_action_log_probs - action_log_probs)
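
The third hunk is a separate correction: passing self.plugin.shard_config by keyword rather than positionally presumably avoids binding it to whatever parameter occupies that positional slot in memory_efficient_logprob. With both log-probability tensors now computed at the same temperature, the per-token KL term visible at the end of this hunk stays consistent with the rollout distribution. The snippet below writes out the full expression as an assumption (the diff only shows its first line), following the k3 KL estimator commonly used in GRPO-style training.

import torch

def per_token_kl_k3(reference_action_log_probs, action_log_probs):
    # k3 estimator of per-token KL(pi_theta || pi_ref):
    #   exp(log p_ref - log p) - (log p_ref - log p) - 1, which is >= 0
    log_ratio = reference_action_log_probs - action_log_probs
    return torch.exp(log_ratio) - log_ratio - 1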
