diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py
index 19a6db97b..0fae27277 100644
--- a/alf/algorithms/algorithm.py
+++ b/alf/algorithms/algorithm.py
@@ -29,7 +29,7 @@
 from torch.nn.modules.module import _IncompatibleKeys, _addindent
 
 import alf
-from alf.data_structures import AlgStep, LossInfo, StepType, TimeStep
+from alf.data_structures import AlgStep, LossInfo, StepType, TimeStep, BasicRolloutInfo
 from alf.experience_replayers.replay_buffer import BatchInfo, ReplayBuffer
 from alf.optimizers.utils import GradientNoiseScaleEstimator
 from alf.utils.checkpoint_utils import (is_checkpoint_enabled,
@@ -1368,6 +1368,13 @@ def train_step_offline(self, inputs, state, rollout_info, pre_train=False):
               customized training.
         """
         try:
+            if isinstance(rollout_info, BasicRolloutInfo):
+                logging.log_first_n(
+                    logging.WARNING,
+                    "Detected offline buffer training without Agent wrapper. "
+                    "For best compatibility, it is advised to use the Agent wrapper.",
+                    n=1)
+                rollout_info = rollout_info.rl
             return self.train_step(inputs, state, rollout_info)
         except:
             # the default train_step is not compatible with the
diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index f2363bd41..5044ba3e9 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -27,13 +27,12 @@
 from alf.algorithms.config import TrainerConfig
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
 from alf.algorithms.one_step_loss import OneStepTDLoss
-from alf.algorithms.rl_algorithm import RLAlgorithm
-from alf.data_structures import TimeStep, Experience, LossInfo, namedtuple
+from alf.data_structures import TimeStep, LossInfo, namedtuple
 from alf.data_structures import AlgStep, StepType
 from alf.nest import nest
 import alf.nest.utils as nest_utils
 from alf.networks import ActorDistributionNetwork, CriticNetwork
-from alf.networks import QNetwork, QRNNNetwork
+from alf.networks import QNetwork
 from alf.tensor_specs import TensorSpec, BoundedTensorSpec
 from alf.utils import losses, common, dist_utils, math_ops
 from alf.utils.normalizers import ScalarAdaptiveNormalizer
@@ -847,6 +846,7 @@ def _select_q_value(self, action, q_values):
 
     def _critic_train_step(self, observation, target_observation,
                            state: SacCriticState, rollout_info: SacInfo,
                            action, action_distribution):
+
         critics, critics_state = self._compute_critics(
             self._critic_networks, observation,
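
Note (not part of the patch): the sketch below illustrates, under simplified assumptions, the unwrapping that the new code in Algorithm.train_step_offline performs when training from an offline replay buffer without the Agent wrapper. The namedtuple definitions are minimal stand-ins for the alf.data_structures types; only the rl field and the isinstance check are taken from the diff, the other fields and the helper name unwrap_rollout_info are hypothetical.

# Hedged sketch of the BasicRolloutInfo unwrapping added in this patch.
# Stand-in types below are simplified; they are not the real ALF definitions.
from collections import namedtuple

BasicRolloutInfo = namedtuple("BasicRolloutInfo", ["rl", "rewards"])
SacInfo = namedtuple("SacInfo", ["action", "critic"])


def unwrap_rollout_info(rollout_info):
    # When the replayed info comes from an algorithm trained without the
    # Agent wrapper, it arrives wrapped in BasicRolloutInfo; extract the
    # algorithm-specific `rl` field so the regular train_step can consume it.
    if isinstance(rollout_info, BasicRolloutInfo):
        return rollout_info.rl
    return rollout_info


# Usage: a wrapped SacInfo is unwrapped; already-unwrapped info passes through.
wrapped = BasicRolloutInfo(rl=SacInfo(action=None, critic=None), rewards={})
assert unwrap_rollout_info(wrapped) is wrapped.rl
assert unwrap_rollout_info(wrapped.rl) is wrapped.rl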