Skip to content

Commit 4160c9a

Browse files
[BugFix] Fixed broken SACLoss when there is more than one qvalue_network (#3500)
1 parent c472e9b commit 4160c9a

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

test/test_objectives.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4398,6 +4398,28 @@ def test_reset_parameters_recursive(self, version):
43984398
)
43994399
self.reset_parameters_recursive_test(loss_fn)
44004400

4401+
def test_sac_list_qvalue_networks(self, version):
4402+
torch.manual_seed(self.seed)
4403+
td = self._create_mock_data_sac()
4404+
actor = self._create_mock_actor()
4405+
qvalue1 = self._create_mock_qvalue()
4406+
qvalue2 = self._create_mock_qvalue()
4407+
if version == 1:
4408+
value = self._create_mock_value()
4409+
else:
4410+
value = None
4411+
loss_fn = SACLoss(
4412+
actor_network=actor,
4413+
qvalue_network=[qvalue1, qvalue2],
4414+
value_network=value,
4415+
num_qvalue_nets=2,
4416+
)
4417+
with pytest.warns(
4418+
UserWarning, match="No target network updater has been associated"
4419+
) if rl_warnings() else contextlib.nullcontext():
4420+
loss = loss_fn(td)
4421+
assert "loss_qvalue" in loss.keys()
4422+
44014423
@pytest.mark.parametrize("delay_value", (True, False))
44024424
@pytest.mark.parametrize("delay_actor", (True, False))
44034425
@pytest.mark.parametrize("delay_qvalue", (True, False))

torchrl/objectives/common.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,8 @@ def convert_to_functional(
348348
params = TensorDict.from_modules(
349349
*module, as_module=True, expand_identical=True
350350
)
351+
# Use the first module as the functional forward reference.
352+
module = module[0]
351353
else:
352354
params = TensorDict.from_module(module, as_module=True)
353355

torchrl/objectives/sac.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,12 +73,15 @@ class SACLoss(LossModule):
7373
7474
Args:
7575
actor_network (ProbabilisticTensorDictSequential): stochastic actor
76-
qvalue_network (TensorDictModule): Q(s, a) parametric model.
76+
qvalue_network (TensorDictModule | list[TensorDictModule]): Q(s, a) parametric model.
7777
This module typically outputs a ``"state_action_value"`` entry.
7878
If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
7979
times. If a list of modules is passed, their
8080
parameters will be stacked unless they share the same identity (in which case
8181
the original parameter will be expanded).
82+
When a list is provided, the first module is used as the functional forward
83+
reference (its ``in_keys``/``out_keys`` are used), so all modules must share
84+
the same signature.
8285
8386
.. warning:: When a list of parameters if passed, it will **not** be compared against the policy parameters
8487
and all the parameters will be considered as untied.

0 commit comments

Comments (0)