@@ -45,6 +45,7 @@ class CQLLoss(LossModule):
4545 actor_network (ProbabilisticActor): stochastic actor
4646 qvalue_network (TensorDictModule): Q(s, a) parametric model.
4747 This module typically outputs a ``"state_action_value"`` entry.
48+
4849 Keyword args:
4950 loss_function (str, optional): loss function to be used with
5051 the value function loss. Default is `"smooth_l1"`.
@@ -127,8 +128,9 @@ class CQLLoss(LossModule):
127128 alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
128129 entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
129130 loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
131+ loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
130132 loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
131-            loss_alpha_prime: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
133+            loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
132134 loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
133135 batch_size=torch.Size([]),
134136 device=None,
@@ -169,10 +171,10 @@ class CQLLoss(LossModule):
169171 >>> qvalue = ValueOperator(
170172 ... module=module,
171173 ... in_keys=['observation', 'action'])
172-        >>> loss = CQLLoss(actor, qvalue, value)
174+ >>> loss = CQLLoss(actor, qvalue)
173175 >>> batch = [2, ]
174176 >>> action = spec.rand(batch)
175- >>> loss_actor, loss_qvalue, _, _, _, _ = loss(
177+ >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, * _ = loss(
176178 ... observation=torch.randn(*batch, n_obs),
177179 ... action=action,
178180 ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -185,7 +187,7 @@ class CQLLoss(LossModule):
185187 method.
186188
187189 Examples:
188- >>> loss.select_out_keys('loss_actor', 'loss_qvalue')
190+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
189191 >>> loss_actor, loss_qvalue = loss(
190192 ... observation=torch.randn(*batch, n_obs),
191193 ... action=action,
@@ -471,10 +473,11 @@ def out_keys(self):
471473            "loss_qvalue",
472474            "loss_cql",
473475            "loss_alpha",
474-            "loss_alpha_prime",
475476            "alpha",
476477            "entropy",
477478        ]
479+        if self.with_lagrange:
480+            keys.append("loss_alpha_prime")
478481        self._out_keys = keys
479482        return self._out_keys
480483
@@ -876,8 +879,9 @@ class DiscreteCQLLoss(LossModule):
876879
877880
878881 Examples:
879- >>> from torchrl.modules import MLP
882+ >>> from torchrl.modules import MLP, QValueActor
880883 >>> from torchrl.data import OneHotDiscreteTensorSpec
884+ >>> from torchrl.objectives import DiscreteCQLLoss
881885 >>> n_obs, n_act = 4, 3
882886 >>> value_net = MLP(in_features=n_obs, out_features=n_act)
883887 >>> spec = OneHotDiscreteTensorSpec(n_act)
@@ -895,8 +899,11 @@ class DiscreteCQLLoss(LossModule):
895899 >>> loss(data)
896900 TensorDict(
897901 fields={
898- loss: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True),
899- loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
902+ loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
903+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
904+ pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
905+ target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
906+ td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
900907 batch_size=torch.Size([]),
901908 device=None,
902909 is_shared=False)
0 commit comments