
Commit 75b818f

Merge pull request #87 from UT-Austin-RPL/recurrent_depth
Module Actor/Critic Output Heads
2 parents 919a3a9 + e8302b6

19 files changed: +407 -204 lines

amago/agent.py

Lines changed: 19 additions & 23 deletions
@@ -194,6 +194,8 @@ def V(state, critic, action_dist, k) -> float:
         use_target_actor: If True, use a target actor to sample actions used in TD targets.
             Defaults to True.
         use_multigamma: If True, train on multiple discount horizons (:py:class:`~amago.agent.Multigammas`) in parallel. Defaults to True.
+        actor_type: Actor MLP head for producing action distributions. Defaults to :py:class:`~amago.nets.actor_critic.Actor`.
+        critic_type: Critic MLP head for producing Q-values. Defaults to :py:class:`~amago.nets.actor_critic.NCritics`.
     """
 
     def __init__(
@@ -218,6 +220,8 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
+        critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCritics,
     ):
         super().__init__()
         self.obs_space = obs_space
@@ -280,17 +284,13 @@ def __init__(
             "discrete": self.discrete,
             "gammas": self.gammas,
         }
-        self.critics = actor_critic.NCritics(**ac_kwargs, num_critics=num_critics)
-        self.target_critics = actor_critic.NCritics(
-            **ac_kwargs, num_critics=num_critics
-        )
-        self.maximized_critics = actor_critic.NCritics(
-            **ac_kwargs, num_critics=num_critics
-        )
+        self.critics = critic_type(**ac_kwargs, num_critics=num_critics)
+        self.target_critics = critic_type(**ac_kwargs, num_critics=num_critics)
+        self.maximized_critics = critic_type(**ac_kwargs, num_critics=num_critics)
         if self.multibinary:
             ac_kwargs["cont_dist_kind"] = "multibinary"
-        self.actor = actor_critic.Actor(**ac_kwargs)
-        self.target_actor = actor_critic.Actor(**ac_kwargs)
+        self.actor = actor_type(**ac_kwargs)
+        self.target_actor = actor_type(**ac_kwargs)
         # full weight copy to targets
         self.hard_sync_targets()
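The hunks above make the actor/critic output heads injectable: the head classes themselves are passed through the constructor and instantiated with the shared `ac_kwargs`. A minimal standalone sketch of that pattern (the `BaseCriticHead`, `NCritics`, and `NCriticsTwoHot` classes below are simplified stand-ins for the real ones in `amago.nets.actor_critic`, not their actual interfaces):

```python
from typing import Type


class BaseCriticHead:
    """Stand-in for the shared critic-head interface."""

    def __init__(self, state_dim: int, num_critics: int):
        self.state_dim = state_dim
        self.num_critics = num_critics


class NCritics(BaseCriticHead):
    """Stand-in for the default ensemble of Q-value MLPs."""


class NCriticsTwoHot(BaseCriticHead):
    """Stand-in for the distributional two-hot critic."""


def build_critic_heads(critic_type: Type[BaseCriticHead], **ac_kwargs):
    # Same idea as the diff: instantiate whichever class was injected,
    # once each for the online, target, and maximized critics.
    return tuple(critic_type(**ac_kwargs) for _ in range(3))


critics, target_critics, maximized_critics = build_critic_heads(
    NCriticsTwoHot, state_dim=256, num_critics=4
)
```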

@@ -651,12 +651,14 @@ def masked_avg(x_, dim=0):
         ].sum()
 
         binary_filter = filter_ > 0
+        masked_logp_a = logp_a[mask.bool()]
         stats = {
-            "Minimum Action Logprob": logp_a.min(),
-            "Maximum Action Logprob": logp_a.max(),
+            "Minimum Action Logprob": masked_logp_a.min(),
+            "Maximum Action Logprob": masked_logp_a.max(),
+            "Mean Action Logprob": masked_logp_a.mean(),
             "Filter Max": filter_.max(),
             "Filter Min": filter_.min(),
-            "Filter Mean": filter_.mean(),
+            "Filter Mean": (mask * filter_).sum() / mask.sum(),
             "Pct. of Actions Approved by Binary FBC Filter (All Gammas)": utils.masked_avg(
                 binary_filter, mask
             )
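The logging change above keeps padded timesteps out of the log-probability and filter statistics. A small sketch of the difference, assuming `logp_a` is a zero-padded (batch, time) tensor and `mask` marks real timesteps with 1:

```python
import torch

# Illustrative shapes only: (batch, time) tensors with right-padding.
logp_a = torch.tensor([[-1.2, -0.3, 0.0], [-2.5, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # 1 = real timestep

naive_mean = logp_a.mean()                        # padding drags stats toward 0
masked_logp_a = logp_a[mask.bool()]               # keep only valid entries
masked_min = masked_logp_a.min()
masked_max = masked_logp_a.max()
masked_mean = (mask * logp_a).sum() / mask.sum()  # same trick as "Filter Mean"
```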
@@ -708,7 +710,7 @@ class MultiTaskAgent(Agent):
 
     The combination of points 2 and 3 stresses accurate advantage estimates and motivates a change
     in the default value of num_actions_for_value_in_critic_loss from 1 --> 3. Arguments otherwise
-    follow the information listed in amago.agent.Agent.
+    follow the information listed in :py:class:`~amago.agent.Agent`.
     """
 
     def __init__(
@@ -733,6 +735,8 @@ def __init__(
         popart: bool = True,
         use_target_actor: bool = True,
         use_multigamma: bool = True,
+        actor_type: Type[actor_critic.BaseActorHead] = actor_critic.Actor,
+        critic_type: Type[actor_critic.BaseCriticHead] = actor_critic.NCriticsTwoHot,
     ):
         super().__init__(
             obs_space=obs_space,
@@ -755,17 +759,9 @@ def __init__(
             use_multigamma=use_multigamma,
             fbc_filter_func=fbc_filter_func,
             popart=popart,
+            actor_type=actor_type,
+            critic_type=critic_type,
         )
-        critic_kwargs = {
-            "state_dim": self.traj_encoder.emb_dim,
-            "action_dim": self.action_dim,
-            "gammas": self.gammas,
-            "num_critics": self.num_critics,
-        }
-        self.critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.target_critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.maximized_critics = actor_critic.NCriticsTwoHot(**critic_kwargs)
-        self.hard_sync_targets()
 
     def _sample_k_actions(self, dist, k: int):
         raise NotImplementedError
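With head construction moved into the parent class, `MultiTaskAgent` now only forwards `actor_type`/`critic_type` (defaulting the critic to the two-hot head) instead of rebuilding its critics afterward. A hypothetical wrapper showing how a caller might select the heads; `common_kwargs` stands in for whatever other constructor arguments your setup already provides:

```python
from amago import agent
from amago.nets import actor_critic


def make_multi_task_policy(**common_kwargs):
    # Hypothetical helper: only the head types are shown explicitly; the
    # remaining MultiTaskAgent arguments are passed through unchanged.
    return agent.MultiTaskAgent(
        **common_kwargs,
        actor_type=actor_critic.Actor,            # default actor head
        critic_type=actor_critic.NCriticsTwoHot,  # default two-hot critic
    )
```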

amago/experiment.py

Lines changed: 2 additions & 2 deletions
@@ -916,7 +916,7 @@ def compute_loss(self, batch: Batch, log_step: bool) -> dict:
             "Unmasked Batch Size (in Timesteps)": unmasked_batch_size,
         } | update_info
 
-    def _get_grad_norms(self):
+    def get_grad_norms(self):
         """Get gradient norms for logging."""
         ggn = utils.get_grad_norm
         pi = self.policy
@@ -952,7 +952,7 @@ def train_step(self, batch: Batch, log_step: bool):
         if log_step:
             l.update(
                 {"Learning Rate": self.lr_schedule.get_last_lr()[0]}
-                | self._get_grad_norms()
+                | self.get_grad_norms()
             )
         self.optimizer.step()
         self.lr_schedule.step()
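Dropping the leading underscore lets code outside the class treat the gradient-norm helper as part of the public interface. A hypothetical sketch of forwarding those stats to an external logger (`experiment` and `logger` are placeholders for your own objects):

```python
def log_grad_norms(experiment, logger) -> None:
    # `experiment` is assumed to be an amago Experiment instance and `logger`
    # anything with a .log(dict) method; both names are placeholders.
    stats = experiment.get_grad_norms()  # previously the private _get_grad_norms()
    logger.log(stats)
```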
