File tree Expand file tree Collapse file tree 3 files changed +28
-1
lines changed
Expand file tree Collapse file tree 3 files changed +28
-1
lines changed Original file line number Diff line number Diff line change @@ -4398,6 +4398,28 @@ def test_reset_parameters_recursive(self, version):
43984398 )
43994399 self.reset_parameters_recursive_test(loss_fn)
44004400
4401+ def test_sac_list_qvalue_networks(self, version):
4402+ torch.manual_seed(self.seed)
4403+ td = self._create_mock_data_sac()
4404+ actor = self._create_mock_actor()
4405+ qvalue1 = self._create_mock_qvalue()
4406+ qvalue2 = self._create_mock_qvalue()
4407+ if version == 1:
4408+ value = self._create_mock_value()
4409+ else:
4410+ value = None
4411+ loss_fn = SACLoss(
4412+ actor_network=actor,
4413+ qvalue_network=[qvalue1, qvalue2],
4414+ value_network=value,
4415+ num_qvalue_nets=2,
4416+ )
4417+ with pytest.warns(
4418+ UserWarning, match="No target network updater has been associated"
4419+ ) if rl_warnings() else contextlib.nullcontext():
4420+ loss = loss_fn(td)
4421+ assert "loss_qvalue" in loss.keys()
4422+
44014423 @pytest.mark.parametrize("delay_value", (True, False))
44024424 @pytest.mark.parametrize("delay_actor", (True, False))
44034425 @pytest.mark.parametrize("delay_qvalue", (True, False))
Original file line number Diff line number Diff line change @@ -348,6 +348,8 @@ def convert_to_functional(
348348 params = TensorDict .from_modules (
349349 * module , as_module = True , expand_identical = True
350350 )
351+ # Use the first module as the functional forward reference.
352+ module = module [0 ]
351353 else :
352354 params = TensorDict .from_module (module , as_module = True )
353355
Original file line number Diff line number Diff line change @@ -73,12 +73,15 @@ class SACLoss(LossModule):
7373
7474 Args:
7575 actor_network (ProbabilisticTensorDictSequential): stochastic actor
76- qvalue_network (TensorDictModule): Q(s, a) parametric model.
76+ qvalue_network (TensorDictModule | list[TensorDictModule] ): Q(s, a) parametric model.
7777 This module typically outputs a ``"state_action_value"`` entry.
7878 If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
7979 times. If a list of modules is passed, their
8080 parameters will be stacked unless they share the same identity (in which case
8181 the original parameter will be expanded).
82+ When a list is provided, the first module is used as the functional forward
83+ reference (its ``in_keys``/``out_keys`` are used), so all modules must share
84+ the same signature.
8285
8386 .. warning:: When a list of parameters if passed, it will **not** be compared against the policy parameters
8487 and all the parameters will be considered as untied.
You can’t perform that action at this time.
0 commit comments