import functools
-from typing import Dict, List, Any
+from typing import Any

import torch

+from recipe.AEnt.aent_args import AEntPPOActorConfig
+from recipe.AEnt.functional import gather_logprobs_clamped_entropy

-from areal.api.cli_args import MicroBatchSpec, PPOActorConfig
+from areal.api.cli_args import MicroBatchSpec
from areal.api.engine_api import TrainEngine
from areal.engine.fsdp_engine import FSDPEngine
from areal.engine.ppo.actor import PPOActor
from areal.utils import stats_tracker
from areal.utils.data import split_padded_tensor_dict_into_mb_list
from areal.utils.functional import (
    dynamic_sampling,
-    gather_logprobs,
    gather_logprobs_entropy,
    ppo_actor_loss_fn,
-    reward_overlong_penalty,
)
-from recipe.AEnt.aent_args import AEntPPOActorConfig
-from recipe.AEnt.functional import gather_logprobs_clamped_entropy


class AEntPPOActor(PPOActor):
-
    def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine):
        super().__init__(config, engine)
        self.entropy_coeff = config.aent.entropy_coeff
@@ -39,7 +36,7 @@ def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine):
    @stats_tracker.scope_func_wrapper("aent_ppo_actor")
    def aent_ppo_update(
        self, data: dict[str, Any], global_step: int
-    ) -> List[Dict[str, float]]:
+    ) -> list[dict[str, float]]:
        with stats_tracker.scope("dynamic_sampling"):
            if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
                data, sampling_stat = dynamic_sampling(data, self.group_size)
@@ -156,7 +153,6 @@ def aent_ppo_update(


class FSDPAEntPPOActor(FSDPEngine):
-
    def __init__(self, config: AEntPPOActorConfig):
        super().__init__(config)
        self.actor = AEntPPOActor(config, self)
@@ -169,14 +165,14 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
    def compute_advantages(self, *args, **kwargs) -> None:
        self.actor.compute_advantages(*args, **kwargs)

-    def aent_ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
+    def aent_ppo_update(self, *args, **kwargs) -> list[dict[str, float]]:
        return self.actor.aent_ppo_update(*args, **kwargs)


# AEnt regularized grpo loss
def aent_grpo_loss_fn(
    logits: torch.Tensor,
-    input_data: Dict,
+    input_data: dict,
    temperature: float,
    eps_clip: float,
    entropy_coeff: float,
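
The body of aent_grpo_loss_fn is cut off in this hunk. As a rough illustration of the idea the names above suggest (a clipped PPO/GRPO surrogate combined with a clamped entropy bonus), here is a minimal, self-contained PyTorch sketch; the function name, tensor shapes, and the entropy_clamp bound are assumptions made for illustration, not the repository's implementation.

import torch
import torch.nn.functional as F


def clamped_entropy_ppo_loss(
    logits: torch.Tensor,        # [batch, seqlen, vocab] new-policy logits (assumed shape)
    input_ids: torch.Tensor,     # [batch, seqlen] sampled token ids
    old_logprobs: torch.Tensor,  # [batch, seqlen] behavior-policy log-probs
    advantages: torch.Tensor,    # [batch, seqlen] group-normalized advantages
    loss_mask: torch.Tensor,     # [batch, seqlen] 1 on response tokens, 0 elsewhere
    temperature: float = 1.0,
    eps_clip: float = 0.2,
    entropy_coeff: float = 1e-3,
    entropy_clamp: float = 2.0,  # assumed cap on the per-token entropy bonus
) -> torch.Tensor:
    # Apply the sampling temperature before computing token distributions.
    logp_all = F.log_softmax(logits.float() / temperature, dim=-1)
    # Log-probabilities of the tokens that were actually sampled.
    logprobs = torch.gather(logp_all, -1, input_ids.unsqueeze(-1)).squeeze(-1)
    # Per-token entropy, clamped so the bonus cannot dominate the policy loss.
    entropy = -(logp_all.exp() * logp_all).sum(-1)
    clamped_entropy = entropy.clamp(max=entropy_clamp)

    # Standard PPO clipped surrogate on the importance ratio.
    ratio = torch.exp(logprobs - old_logprobs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    pg_loss = -torch.min(surr1, surr2)

    # Subtracting the (clamped) entropy term encourages exploration up to the cap.
    per_token_loss = pg_loss - entropy_coeff * clamped_entropy
    mask = loss_mask.float()
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)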