From 74240c4a26cb30bb3d00b6aa3a56c8ba2ff0a037 Mon Sep 17 00:00:00 2001
From: Andrew Choi
Date: Mon, 12 May 2025 14:10:53 -0700
Subject: [PATCH 1/3] [Bug Fix] Enable hybrid SAC training

---
 alf/algorithms/sac_algorithm.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index f2363bd41..bd1176afa 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -27,13 +27,13 @@
 from alf.algorithms.config import TrainerConfig
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
 from alf.algorithms.one_step_loss import OneStepTDLoss
-from alf.algorithms.rl_algorithm import RLAlgorithm
-from alf.data_structures import TimeStep, Experience, LossInfo, namedtuple
+from alf.data_structures import TimeStep, LossInfo, namedtuple, \
+    BasicRolloutInfo
 from alf.data_structures import AlgStep, StepType
 from alf.nest import nest
 import alf.nest.utils as nest_utils
 from alf.networks import ActorDistributionNetwork, CriticNetwork
-from alf.networks import QNetwork, QRNNNetwork
+from alf.networks import QNetwork
 from alf.tensor_specs import TensorSpec, BoundedTensorSpec
 from alf.utils import losses, common, dist_utils, math_ops
 from alf.utils.normalizers import ScalarAdaptiveNormalizer
@@ -847,6 +847,10 @@ def _select_q_value(self, action, q_values):
     def _critic_train_step(self, observation, target_observation,
                            state: SacCriticState, rollout_info: SacInfo,
                            action, action_distribution):
+
+        if isinstance(rollout_info, BasicRolloutInfo):
+            rollout_info = rollout_info.rl
+
         critics, critics_state = self._compute_critics(
             self._critic_networks,
             observation,

From 2c7e7cfc610f03b07ad458b4b62d0ac0af623740 Mon Sep 17 00:00:00 2001
From: Andrew Choi
Date: Mon, 12 May 2025 14:29:54 -0700
Subject: [PATCH 2/3] address comments

---
 alf/algorithms/algorithm.py     |  4 +++-
 alf/algorithms/sac_algorithm.py | 12 +++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py
index 19a6db97b..f2e069332 100644
--- a/alf/algorithms/algorithm.py
+++ b/alf/algorithms/algorithm.py
@@ -29,7 +29,7 @@
 from torch.nn.modules.module import _IncompatibleKeys, _addindent
 
 import alf
-from alf.data_structures import AlgStep, LossInfo, StepType, TimeStep
+from alf.data_structures import AlgStep, LossInfo, StepType, TimeStep, BasicRolloutInfo
 from alf.experience_replayers.replay_buffer import BatchInfo, ReplayBuffer
 from alf.optimizers.utils import GradientNoiseScaleEstimator
 from alf.utils.checkpoint_utils import (is_checkpoint_enabled,
@@ -1368,6 +1368,8 @@ def train_step_offline(self, inputs, state, rollout_info, pre_train=False):
             customized training.
         """
         try:
+            if isinstance(rollout_info, BasicRolloutInfo):
+                rollout_info = rollout_info.rl
             return self.train_step(inputs, state, rollout_info)
         except:
             # the default train_step is not compatible with the
diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index bd1176afa..326413c6c 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -28,7 +28,7 @@
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
 from alf.algorithms.one_step_loss import OneStepTDLoss
 from alf.data_structures import TimeStep, LossInfo, namedtuple, \
-    BasicRolloutInfo
+    BasicRLInfo
 from alf.data_structures import AlgStep, StepType
 from alf.nest import nest
 import alf.nest.utils as nest_utils
@@ -845,11 +845,9 @@ def _select_q_value(self, action, q_values):
         return q_values.gather(2, action).squeeze(2)
 
     def _critic_train_step(self, observation, target_observation,
-                           state: SacCriticState, rollout_info: SacInfo,
-                           action, action_distribution):
-
-        if isinstance(rollout_info, BasicRolloutInfo):
-            rollout_info = rollout_info.rl
+                           state: SacCriticState,
+                           rollout_info: SacInfo | BasicRLInfo, action,
+                           action_distribution):
 
         critics, critics_state = self._compute_critics(
             self._critic_networks,
@@ -901,7 +899,7 @@ def _alpha_train_step(self, log_pi):
         return sum(nest.flatten(alpha_loss))
 
     def train_step(self, inputs: TimeStep, state: SacState,
-                   rollout_info: SacInfo):
+                   rollout_info: SacInfo | BasicRLInfo):
         assert not self._is_eval
         self._training_started = True
         if self._target_repr_alg is not None:

From 9db5652970cc8ab907202b6bf94584fb357b02be Mon Sep 17 00:00:00 2001
From: Andrew Choi
Date: Tue, 13 May 2025 13:48:06 -0700
Subject: [PATCH 3/3] address comments

---
 alf/algorithms/algorithm.py     |  5 +++++
 alf/algorithms/sac_algorithm.py | 10 ++++------
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py
index f2e069332..0fae27277 100644
--- a/alf/algorithms/algorithm.py
+++ b/alf/algorithms/algorithm.py
@@ -1369,6 +1369,11 @@ def train_step_offline(self, inputs, state, rollout_info, pre_train=False):
         """
         try:
             if isinstance(rollout_info, BasicRolloutInfo):
+                logging.log_first_n(
+                    logging.WARNING,
+                    "Detected offline buffer training without Agent wrapper. "
+                    "For best compatibility, it is advised to use the Agent wrapper.",
+                    n=1)
                 rollout_info = rollout_info.rl
             return self.train_step(inputs, state, rollout_info)
         except:
diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index 326413c6c..5044ba3e9 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -27,8 +27,7 @@
 from alf.algorithms.config import TrainerConfig
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
 from alf.algorithms.one_step_loss import OneStepTDLoss
-from alf.data_structures import TimeStep, LossInfo, namedtuple, \
-    BasicRLInfo
+from alf.data_structures import TimeStep, LossInfo, namedtuple
 from alf.data_structures import AlgStep, StepType
 from alf.nest import nest
 import alf.nest.utils as nest_utils
@@ -845,9 +844,8 @@ def _select_q_value(self, action, q_values):
         return q_values.gather(2, action).squeeze(2)
 
     def _critic_train_step(self, observation, target_observation,
-                           state: SacCriticState,
-                           rollout_info: SacInfo | BasicRLInfo, action,
-                           action_distribution):
+                           state: SacCriticState, rollout_info: SacInfo,
+                           action, action_distribution):
 
         critics, critics_state = self._compute_critics(
             self._critic_networks,
@@ -899,7 +897,7 @@ def _alpha_train_step(self, log_pi):
         return sum(nest.flatten(alpha_loss))
 
     def train_step(self, inputs: TimeStep, state: SacState,
-                   rollout_info: SacInfo | BasicRLInfo):
+                   rollout_info: SacInfo):
         assert not self._is_eval
         self._training_started = True
         if self._target_repr_alg is not None: