
Commit b314da1

fix small bug
1 parent 245c8c2 commit b314da1

1 file changed: +3 −3 lines

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 3 additions & 3 deletions
@@ -294,7 +294,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 
                 if self.booster.plugin.stage_manager.is_last_stage():
                     reference_action_log_probs = memory_efficient_logprob(
-                        reference_model_outputs["outputs"]["logits"],
+                        reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
                         shard_config=self.plugin.shard_config,
@@ -321,7 +321,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
                         action_log_probs = memory_efficient_logprob(
-                            action_logits,
+                            action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
                             num_action,
                             shard_config=self.plugin.shard_config,
@@ -388,7 +388,7 @@ def _criterion(outputs, inputs):
                     reference_model_logits / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
-                    self.plugin.shard_config,
+                    shard_config=self.plugin.shard_config,
                 )
                 per_token_kl = (
                     torch.exp(reference_action_log_probs - action_log_probs)
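The change divides the policy and reference logits by the sampling temperature before computing token log-probabilities, matching the scaling already applied at line 388, and passes shard_config as a keyword argument instead of positionally. The temperature scaling matters because the rollouts were sampled from the temperature-adjusted distribution, so the log-probs used for the KL term should describe that same distribution. A minimal sketch of the idea in plain PyTorch (this is not the repo's memory_efficient_logprob helper; the function name, tensor names, and shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

def temperature_scaled_logprobs(logits: torch.Tensor,
                                token_ids: torch.Tensor,
                                temperature: float) -> torch.Tensor:
    # logits:    [batch, seq_len, vocab] raw model outputs
    # token_ids: [batch, seq_len] tokens actually sampled during rollout
    # Scale the logits exactly as the sampler did at generation time;
    # otherwise the computed log-probs describe a different distribution
    # than the one the responses were drawn from.
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    return log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)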
