Skip to content

Commit 777388b

Browse files
committed
upd on calculate_entropy
1 parent 09b6fcf commit 777388b

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

trinity/trainer/verl/dp_actor.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor
3333

3434
from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
35+
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
3536
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
3637
from trinity.algorithm.utils import prefix_metrics
3738
from trinity.common.config import AlgorithmConfig
@@ -232,8 +233,11 @@ def update_policy(self, data: DataProto):
232233
assert response_mask.shape == attention_mask[:, -response_length:].shape
233234

234235
# all return: (bsz, response_length)
236+
calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
235237
entropy, log_prob = self._forward_micro_batch(
236-
micro_batch=data, temperature=temperature, calculate_entropy=True
238+
micro_batch=data,
239+
temperature=temperature,
240+
calculate_entropy=calculate_entropy,
237241
)
238242

239243
kwargs = {

0 commit comments

Comments
 (0)