Skip to content

Commit fb3d8cc

Browse files
BY571 authored and vmoens committed
[BugFix] Update cql docstring example (#1951)
Co-authored-by: Vincent Moens <[email protected]>
1 parent ad02db6 commit fb3d8cc

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

torchrl/objectives/cql.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class CQLLoss(LossModule):
4545
actor_network (ProbabilisticActor): stochastic actor
4646
qvalue_network (TensorDictModule): Q(s, a) parametric model.
4747
This module typically outputs a ``"state_action_value"`` entry.
48+
4849
Keyword args:
4950
loss_function (str, optional): loss function to be used with
5051
the value function loss. Default is `"smooth_l1"`.
@@ -127,8 +128,9 @@ class CQLLoss(LossModule):
127128
alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
128129
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
129130
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
131+
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
130132
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
131-
loss_alpha_prime: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
133+
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
132134
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
133135
batch_size=torch.Size([]),
134136
device=None,
@@ -169,10 +171,10 @@ class CQLLoss(LossModule):
169171
>>> qvalue = ValueOperator(
170172
... module=module,
171173
... in_keys=['observation', 'action'])
172-
>>> loss = CQLLoss(actor, qvalue, value)
174+
>>> loss = CQLLoss(actor, qvalue)
173175
>>> batch = [2, ]
174176
>>> action = spec.rand(batch)
175-
>>> loss_actor, loss_qvalue, _, _, _, _ = loss(
177+
>>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
176178
... observation=torch.randn(*batch, n_obs),
177179
... action=action,
178180
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -185,7 +187,7 @@ class CQLLoss(LossModule):
185187
method.
186188
187189
Examples:
188-
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
190+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
189191
>>> loss_actor, loss_qvalue = loss(
190192
... observation=torch.randn(*batch, n_obs),
191193
... action=action,
@@ -471,10 +473,11 @@ def out_keys(self):
471473
"loss_qvalue",
472474
"loss_cql",
473475
"loss_alpha",
474-
"loss_alpha_prime",
475476
"alpha",
476477
"entropy",
477478
]
479+
if self.with_lagrange:
480+
keys.append("loss_alpha_prime")
478481
self._out_keys = keys
479482
return self._out_keys
480483

@@ -876,8 +879,9 @@ class DiscreteCQLLoss(LossModule):
876879
877880
878881
Examples:
879-
>>> from torchrl.modules import MLP
882+
>>> from torchrl.modules import MLP, QValueActor
880883
>>> from torchrl.data import OneHotDiscreteTensorSpec
884+
>>> from torchrl.objectives import DiscreteCQLLoss
881885
>>> n_obs, n_act = 4, 3
882886
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
883887
>>> spec = OneHotDiscreteTensorSpec(n_act)
@@ -895,8 +899,11 @@ class DiscreteCQLLoss(LossModule):
895899
>>> loss(data)
896900
TensorDict(
897901
fields={
898-
loss: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True),
899-
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
902+
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
903+
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
904+
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
905+
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
906+
td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
900907
batch_size=torch.Size([]),
901908
device=None,
902909
is_shared=False)

0 commit comments

Comments (0)