File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change 3232from verl .workers .actor .dp_actor import DataParallelPPOActor as DPActor
3333
3434from trinity .algorithm import ENTROPY_LOSS_FN , KL_FN , POLICY_LOSS_FN
35+ from trinity .algorithm .entropy_loss_fn .entropy_loss_fn import DummyEntropyLossFn
3536from trinity .algorithm .kl_fn .kl_fn import DummyKLFn
3637from trinity .algorithm .utils import prefix_metrics
3738from trinity .common .config import AlgorithmConfig
@@ -232,8 +233,11 @@ def update_policy(self, data: DataProto):
232233 assert response_mask .shape == attention_mask [:, - response_length :].shape
233234
234235 # all return: (bsz, response_length)
236+ calculate_entropy = self .entropy_loss_fn != DummyEntropyLossFn
235237 entropy , log_prob = self ._forward_micro_batch (
236- micro_batch = data , temperature = temperature , calculate_entropy = True
238+ micro_batch = data ,
239+ temperature = temperature ,
240+ calculate_entropy = calculate_entropy ,
237241 )
238242
239243 kwargs = {
You can’t perform that action at this time.
0 commit comments