@@ -73,20 +73,22 @@ class IQLLoss(LossModule):
7373 ... in_keys=["loc", "scale"],
7474 ... spec=spec,
7575 ... distribution_class=TanhNormal)
76- >>> class ValueClass (nn.Module):
76+ >>> class QValueClass(nn.Module):
7777 ... def __init__(self):
7878 ... super().__init__()
7979 ... self.linear = nn.Linear(n_obs + n_act, 1)
8080 ... def forward(self, obs, act):
8181 ... return self.linear(torch.cat([obs, act], -1))
82- >>> module = ValueClass()
83- >>> qvalue = ValueOperator(
84- ... module=module,
85- ... in_keys=['observation', 'action'])
86- >>> module = nn.Linear(n_obs, 1)
87- >>> value = ValueOperator(
88- ... module=module,
89- ... in_keys=["observation"])
82+ >>> qvalue = SafeModule(
83+ ... QValueClass(),
84+ ... in_keys=["observation", "action"],
85+ ... out_keys=["state_action_value"],
86+ ... )
87+ >>> value = SafeModule(
88+ ... nn.Linear(n_obs, 1),
89+ ... in_keys=["observation"],
90+ ... out_keys=["state_value"],
91+ ... )
9092 >>> loss = IQLLoss(actor, qvalue, value)
9193 >>> batch = [2, ]
9294 >>> action = spec.rand(batch)
@@ -134,20 +136,22 @@ class IQLLoss(LossModule):
134136 ... in_keys=["loc", "scale"],
135137 ... spec=spec,
136138 ... distribution_class=TanhNormal)
137- >>> class ValueClass (nn.Module):
139+ >>> class QValueClass(nn.Module):
138140 ... def __init__(self):
139141 ... super().__init__()
140142 ... self.linear = nn.Linear(n_obs + n_act, 1)
141143 ... def forward(self, obs, act):
142144 ... return self.linear(torch.cat([obs, act], -1))
143- >>> module = ValueClass()
144- >>> qvalue = ValueOperator(
145- ... module=module,
146- ... in_keys=['observation', 'action'])
147- >>> module = nn.Linear(n_obs, 1)
148- >>> value = ValueOperator(
149- ... module=module,
150- ... in_keys=["observation"])
145+ >>> qvalue = SafeModule(
146+ ... QValueClass(),
147+ ... in_keys=["observation", "action"],
148+ ... out_keys=["state_action_value"],
149+ ... )
150+ >>> value = SafeModule(
151+ ... nn.Linear(n_obs, 1),
152+ ... in_keys=["observation"],
153+ ... out_keys=["state_value"],
154+ ... )
151155 >>> loss = IQLLoss(actor, qvalue, value)
152156 >>> batch = [2, ]
153157 >>> action = spec.rand(batch)
@@ -165,7 +169,7 @@ class IQLLoss(LossModule):
165169 method.
166170
167171 Examples:
168- >>> loss.select_out_keys('loss_actor', 'loss_qvalue')
172+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
169173 >>> loss_actor, loss_qvalue = loss(
170174 ... observation=torch.randn(*batch, n_obs),
171175 ... action=action,
@@ -495,7 +499,7 @@ class DiscreteIQLLoss(IQLLoss):
495499
496500 Args:
497501 actor_network (ProbabilisticActor): stochastic actor
498- qvalue_network (TensorDictModule): Q(s) parametric model
502+ qvalue_network (TensorDictModule): Q(s, a) parametric model.
499503 value_network (TensorDictModule, optional): V(s) parametric model.
500504
501505 Keyword Args:
@@ -526,34 +530,33 @@ class DiscreteIQLLoss(IQLLoss):
526530 >>> import torch
527531 >>> from torch import nn
528532 >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
529- >>> from torchrl.modules.distributions.continuous import NormalParamWrapper
530533 >>> from torchrl.modules.distributions.discrete import OneHotCategorical
531- >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
534+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
532535 >>> from torchrl.modules.tensordict_module.common import SafeModule
533536 >>> from torchrl.objectives.iql import DiscreteIQLLoss
534537 >>> from tensordict import TensorDict
535538 >>> n_act, n_obs = 4, 3
536539 >>> spec = OneHotDiscreteTensorSpec(n_act)
537- >>> module = TensorDictModule (nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
540+ >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
538541 >>> actor = ProbabilisticActor(
539542 ... module=module,
540543 ... in_keys=["logits"],
541544 ... out_keys=["action"],
542545 ... spec=spec,
543546 ... distribution_class=OneHotCategorical)
544- >>> qvalue = TensorDictModule (
545- ... nn.Linear(n_obs),
547+ >>> qvalue = SafeModule(
548+ ... nn.Linear(n_obs, n_act),
546549 ... in_keys=["observation"],
547550 ... out_keys=["state_action_value"],
548551 ... )
549- >>> value = TensorDictModule (
550- ... nn.Linear(n_obs),
552+ >>> value = SafeModule(
553+ ... nn.Linear(n_obs, 1),
551554 ... in_keys=["observation"],
552555 ... out_keys=["state_value"],
553556 ... )
554557 >>> loss = DiscreteIQLLoss(actor, qvalue, value)
555558 >>> batch = [2, ]
556- >>> action = spec.rand(batch)
559+ >>> action = spec.rand(batch).long()
557560 >>> data = TensorDict({
558561 ... "observation": torch.randn(*batch, n_obs),
559562 ... "action": action,
@@ -585,40 +588,33 @@ class DiscreteIQLLoss(IQLLoss):
585588 >>> import torch
586589 >>> from torch import nn
587590 >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
588- >>> from torchrl.modules.distributions.continuous import NormalParamWrapper
589591 >>> from torchrl.modules.distributions.discrete import OneHotCategorical
590- >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
592+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
591593 >>> from torchrl.modules.tensordict_module.common import SafeModule
592594 >>> from torchrl.objectives.iql import DiscreteIQLLoss
593- >>> from tensordict import TensorDict
594595 >>> _ = torch.manual_seed(42)
595596 >>> n_act, n_obs = 4, 3
596597 >>> spec = OneHotDiscreteTensorSpec(n_act)
597- >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
598- >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
598+ >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
599599 >>> actor = ProbabilisticActor(
600600 ... module=module,
601601 ... in_keys=["logits"],
602602 ... out_keys=["action"],
603603 ... spec=spec,
604604 ... distribution_class=OneHotCategorical)
605- >>> class ValueClass(nn.Module):
606- ... def __init__(self):
607- ... super().__init__()
608- ... self.linear = nn.Linear(n_obs, n_act)
609- ... def forward(self, obs):
610- ... return self.linear(obs)
611- >>> module = ValueClass()
612- >>> qvalue = ValueOperator(
613- ... module=module,
614- ... in_keys=['observation'])
615- >>> module = nn.Linear(n_obs, 1)
616- >>> value = ValueOperator(
617- ... module=module,
618- ... in_keys=["observation"])
605+ >>> qvalue = SafeModule(
606+ ... nn.Linear(n_obs, n_act),
607+ ... in_keys=["observation"],
608+ ... out_keys=["state_action_value"],
609+ ... )
610+ >>> value = SafeModule(
611+ ... nn.Linear(n_obs, 1),
612+ ... in_keys=["observation"],
613+ ... out_keys=["state_value"],
614+ ... )
619615 >>> loss = DiscreteIQLLoss(actor, qvalue, value)
620616 >>> batch = [2, ]
621- >>> action = spec.rand(batch)
617+ >>> action = spec.rand(batch).long()
622618 >>> loss_actor, loss_qvalue, loss_value, entropy = loss(
623619 ... observation=torch.randn(*batch, n_obs),
624620 ... action=action,
@@ -633,7 +629,7 @@ class DiscreteIQLLoss(IQLLoss):
633629 method.
634630
635631 Examples:
636- >>> loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
632+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
637633 >>> loss_actor, loss_qvalue, loss_value = loss(
638634 ... observation=torch.randn(*batch, n_obs),
639635 ... action=action,
0 commit comments