Skip to content

Commit bfb4037

Browse files
BY571 authored and vmoens committed
[BugFix] Update iql docstring example (#1950)
Co-authored-by: Vincent Moens <[email protected]>
1 parent fb3d8cc commit bfb4037

File tree

1 file changed

+45
-49
lines changed

1 file changed

+45
-49
lines changed

torchrl/objectives/iql.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,22 @@ class IQLLoss(LossModule):
7373
... in_keys=["loc", "scale"],
7474
... spec=spec,
7575
... distribution_class=TanhNormal)
76-
>>> class ValueClass(nn.Module):
76+
>>> class QValueClass(nn.Module):
7777
... def __init__(self):
7878
... super().__init__()
7979
... self.linear = nn.Linear(n_obs + n_act, 1)
8080
... def forward(self, obs, act):
8181
... return self.linear(torch.cat([obs, act], -1))
82-
>>> module = ValueClass()
83-
>>> qvalue = ValueOperator(
84-
... module=module,
85-
... in_keys=['observation', 'action'])
86-
>>> module = nn.Linear(n_obs, 1)
87-
>>> value = ValueOperator(
88-
... module=module,
89-
... in_keys=["observation"])
82+
>>> qvalue = SafeModule(
83+
... QValueClass(),
84+
... in_keys=["observation", "action"],
85+
... out_keys=["state_action_value"],
86+
... )
87+
>>> value = SafeModule(
88+
... nn.Linear(n_obs, 1),
89+
... in_keys=["observation"],
90+
... out_keys=["state_value"],
91+
... )
9092
>>> loss = IQLLoss(actor, qvalue, value)
9193
>>> batch = [2, ]
9294
>>> action = spec.rand(batch)
@@ -134,20 +136,22 @@ class IQLLoss(LossModule):
134136
... in_keys=["loc", "scale"],
135137
... spec=spec,
136138
... distribution_class=TanhNormal)
137-
>>> class ValueClass(nn.Module):
139+
>>> class QValueClass(nn.Module):
138140
... def __init__(self):
139141
... super().__init__()
140142
... self.linear = nn.Linear(n_obs + n_act, 1)
141143
... def forward(self, obs, act):
142144
... return self.linear(torch.cat([obs, act], -1))
143-
>>> module = ValueClass()
144-
>>> qvalue = ValueOperator(
145-
... module=module,
146-
... in_keys=['observation', 'action'])
147-
>>> module = nn.Linear(n_obs, 1)
148-
>>> value = ValueOperator(
149-
... module=module,
150-
... in_keys=["observation"])
145+
>>> qvalue = SafeModule(
146+
... QValueClass(),
147+
... in_keys=["observation", "action"],
148+
... out_keys=["state_action_value"],
149+
... )
150+
>>> value = SafeModule(
151+
... nn.Linear(n_obs, 1),
152+
... in_keys=["observation"],
153+
... out_keys=["state_value"],
154+
... )
151155
>>> loss = IQLLoss(actor, qvalue, value)
152156
>>> batch = [2, ]
153157
>>> action = spec.rand(batch)
@@ -165,7 +169,7 @@ class IQLLoss(LossModule):
165169
method.
166170
167171
Examples:
168-
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
172+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
169173
>>> loss_actor, loss_qvalue = loss(
170174
... observation=torch.randn(*batch, n_obs),
171175
... action=action,
@@ -495,7 +499,7 @@ class DiscreteIQLLoss(IQLLoss):
495499
496500
Args:
497501
actor_network (ProbabilisticActor): stochastic actor
498-
qvalue_network (TensorDictModule): Q(s) parametric model
502+
qvalue_network (TensorDictModule): Q(s, a) parametric model.
499503
value_network (TensorDictModule, optional): V(s) parametric model.
500504
501505
Keyword Args:
@@ -526,34 +530,33 @@ class DiscreteIQLLoss(IQLLoss):
526530
>>> import torch
527531
>>> from torch import nn
528532
>>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
529-
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper
530533
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
531-
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
534+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
532535
>>> from torchrl.modules.tensordict_module.common import SafeModule
533536
>>> from torchrl.objectives.iql import DiscreteIQLLoss
534537
>>> from tensordict import TensorDict
535538
>>> n_act, n_obs = 4, 3
536539
>>> spec = OneHotDiscreteTensorSpec(n_act)
537-
>>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
540+
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
538541
>>> actor = ProbabilisticActor(
539542
... module=module,
540543
... in_keys=["logits"],
541544
... out_keys=["action"],
542545
... spec=spec,
543546
... distribution_class=OneHotCategorical)
544-
>>> qvalue = TensorDictModule(
545-
... nn.Linear(n_obs),
547+
>>> qvalue = SafeModule(
548+
... nn.Linear(n_obs, n_act),
546549
... in_keys=["observation"],
547550
... out_keys=["state_action_value"],
548551
... )
549-
>>> value = TensorDictModule(
550-
... nn.Linear(n_obs),
552+
>>> value = SafeModule(
553+
... nn.Linear(n_obs, 1),
551554
... in_keys=["observation"],
552555
... out_keys=["state_value"],
553556
... )
554557
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
555558
>>> batch = [2, ]
556-
>>> action = spec.rand(batch)
559+
>>> action = spec.rand(batch).long()
557560
>>> data = TensorDict({
558561
... "observation": torch.randn(*batch, n_obs),
559562
... "action": action,
@@ -585,40 +588,33 @@ class DiscreteIQLLoss(IQLLoss):
585588
>>> import torch
586589
>>> from torch import nn
587590
>>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
588-
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper
589591
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
590-
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
592+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
591593
>>> from torchrl.modules.tensordict_module.common import SafeModule
592594
>>> from torchrl.objectives.iql import DiscreteIQLLoss
593-
>>> from tensordict import TensorDict
594595
>>> _ = torch.manual_seed(42)
595596
>>> n_act, n_obs = 4, 3
596597
>>> spec = OneHotDiscreteTensorSpec(n_act)
597-
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
598-
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
598+
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
599599
>>> actor = ProbabilisticActor(
600600
... module=module,
601601
... in_keys=["logits"],
602602
... out_keys=["action"],
603603
... spec=spec,
604604
... distribution_class=OneHotCategorical)
605-
>>> class ValueClass(nn.Module):
606-
... def __init__(self):
607-
... super().__init__()
608-
... self.linear = nn.Linear(n_obs, n_act)
609-
... def forward(self, obs):
610-
... return self.linear(obs)
611-
>>> module = ValueClass()
612-
>>> qvalue = ValueOperator(
613-
... module=module,
614-
... in_keys=['observation'])
615-
>>> module = nn.Linear(n_obs, 1)
616-
>>> value = ValueOperator(
617-
... module=module,
618-
... in_keys=["observation"])
605+
>>> qvalue = SafeModule(
606+
... nn.Linear(n_obs, n_act),
607+
... in_keys=["observation"],
608+
... out_keys=["state_action_value"],
609+
... )
610+
>>> value = SafeModule(
611+
... nn.Linear(n_obs, 1),
612+
... in_keys=["observation"],
613+
... out_keys=["state_value"],
614+
... )
619615
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
620616
>>> batch = [2, ]
621-
>>> action = spec.rand(batch)
617+
>>> action = spec.rand(batch).long()
622618
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
623619
... observation=torch.randn(*batch, n_obs),
624620
... action=action,
@@ -633,7 +629,7 @@ class DiscreteIQLLoss(IQLLoss):
633629
method.
634630
635631
Examples:
636-
>>> loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
632+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
637633
>>> loss_actor, loss_qvalue, loss_value = loss(
638634
... observation=torch.randn(*batch, n_obs),
639635
... action=action,

0 commit comments

Comments (0)