@@ -194,6 +194,8 @@ def V(state, critic, action_dist, k) -> float:
         use_target_actor: If True, use a target actor to sample actions used in TD targets.
             Defaults to True.
         use_multigamma: If True, train on multiple discount horizons (:py:class:`~amago.agent.Multigammas`) in parallel. Defaults to True.
+        actor_type: Actor MLP head for producing action distributions. Defaults to :py:class:`~amago.nets.actor_critic.Actor`.
+        critic_type: Critic MLP head for producing Q-values. Defaults to :py:class:`~amago.nets.actor_critic.NCritics`.
     """

     def __init__(
@@ -218,6 +220,8 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
+        critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCritics,
     ):
         super().__init__()
         self.obs_space = obs_space
@@ -280,17 +284,13 @@ def __init__(
             "discrete": self.discrete,
             "gammas": self.gammas,
         }
-        self.critics = actor_critic.NCritics(**ac_kwargs, num_critics=num_critics)
-        self.target_critics = actor_critic.NCritics(
-            **ac_kwargs, num_critics=num_critics
-        )
-        self.maximized_critics = actor_critic.NCritics(
-            **ac_kwargs, num_critics=num_critics
-        )
+        self.critics = critic_type(**ac_kwargs, num_critics=num_critics)
+        self.target_critics = critic_type(**ac_kwargs, num_critics=num_critics)
+        self.maximized_critics = critic_type(**ac_kwargs, num_critics=num_critics)
         if self.multibinary:
             ac_kwargs["cont_dist_kind"] = "multibinary"
-        self.actor = actor_critic.Actor(**ac_kwargs)
-        self.target_actor = actor_critic.Actor(**ac_kwargs)
+        self.actor = actor_type(**ac_kwargs)
+        self.target_actor = actor_type(**ac_kwargs)
         # full weight copy to targets
         self.hard_sync_targets()

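With the hunk above applied, the actor and critic MLP heads become constructor arguments instead of hard-coded classes. A minimal sketch of how a caller might pre-bind the new kwargs is below; the `make_agent` helper and the `functools.partial` wrapping are illustrative assumptions, not part of this commit, and only the `actor_type`/`critic_type` names come from the diff:

from functools import partial

from amago import agent
from amago.nets import actor_critic

# Hypothetical helper: pre-select the head classes so the remaining Agent
# constructor arguments (obs_space, etc.) can be supplied wherever the
# agent is actually built.
make_agent = partial(
    agent.Agent,
    actor_type=actor_critic.Actor,            # default actor head
    critic_type=actor_critic.NCriticsTwoHot,  # swap in the two-hot critic head
)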
@@ -651,12 +651,14 @@ def masked_avg(x_, dim=0):
         ].sum()

         binary_filter = filter_ > 0
+        masked_logp_a = logp_a[mask.bool()]
         stats = {
-            "Minimum Action Logprob": logp_a.min(),
-            "Maximum Action Logprob": logp_a.max(),
+            "Minimum Action Logprob": masked_logp_a.min(),
+            "Maximum Action Logprob": masked_logp_a.max(),
+            "Mean Action Logprob": masked_logp_a.mean(),
             "Filter Max": filter_.max(),
             "Filter Min": filter_.min(),
-            "Filter Mean": filter_.mean(),
+            "Filter Mean": (mask * filter_).sum() / mask.sum(),
             "Pct. of Actions Approved by Binary FBC Filter (All Gammas)": utils.masked_avg(
                 binary_filter, mask
             )
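The hunk above changes the logging so that action log-probabilities and the FBC filter are summarized only over valid (unmasked) timesteps. A self-contained sketch of the same masking pattern, with made-up tensor shapes standing in for AMAGO's batch layout, is:

import torch

# Made-up shapes: (batch, time, gammas, 1); mask is 0 on padded timesteps.
logp_a = torch.randn(4, 8, 3, 1)
mask = (torch.rand(4, 8, 3, 1) > 0.2).float()

# Boolean indexing drops padded entries entirely before reducing.
masked_logp_a = logp_a[mask.bool()]
print(masked_logp_a.min(), masked_logp_a.max(), masked_logp_a.mean())

# Equivalent masked mean without indexing, matching the "Filter Mean" line.
masked_mean = (mask * logp_a).sum() / mask.sum()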
@@ -708,7 +710,7 @@ class MultiTaskAgent(Agent):

     The combination of points 2 and 3 stresses accurate advantage estimates and motivates a change
     in the default value of num_actions_for_value_in_critic_loss from 1 --> 3. Arguments otherwise
-    follow the information listed in amago.agent.Agent.
+    follow the information listed in :py:class:`~amago.agent.Agent`.
     """

     def __init__(
@@ -733,6 +735,8 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
+        critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCriticsTwoHot,
     ):
         super().__init__(
             obs_space=obs_space,
@@ -755,17 +759,9 @@ def __init__(
             use_multigamma=use_multigamma,
             fbc_filter_func=fbc_filter_func,
             popart=popart,
+            actor_type=actor_type,
+            critic_type=critic_type,
         )
-        critic_kwargs = {
-            "state_dim": self.traj_encoder.emb_dim,
-            "action_dim": self.action_dim,
-            "gammas": self.gammas,
-            "num_critics": self.num_critics,
-        }
-        self.critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.target_critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.maximized_critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.hard_sync_targets()

     def _sample_k_actions(self, dist, k: int):
         raise NotImplementedError
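For context on the num_actions_for_value_in_critic_loss note in the MultiTaskAgent docstring: as the `V(state, critic, action_dist, k)` helper named in the first hunk header suggests, the state value used in TD targets is approximated by averaging the critic over k actions sampled from the policy, so a larger k smooths the advantage estimate. A rough sketch with stand-in modules (not AMAGO's actual critic or distribution classes) is:

import torch

def V(state, critic, action_dist, k) -> float:
    # Monte Carlo estimate of V(s) = E_{a ~ pi}[Q(s, a)] from k sampled actions.
    actions = action_dist.sample((k,))
    q_values = torch.stack([critic(state, a) for a in actions])
    return q_values.mean().item()

# Stand-in state, critic, and action distribution for illustration only.
state = torch.randn(16)
critic = lambda s, a: s.sum() + a.sum()
action_dist = torch.distributions.Normal(torch.zeros(4), torch.ones(4))
print(V(state, critic, action_dist, k=3))  # k=1 is noisier; the new default is 3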