Commit e4b6c92

[BugFix] Fix target_entropy computation for composite action specs
- Handle composite action distributions in SACLoss.target_entropy by summing the numel of all leaf specs
- Add a warning in SafeProbabilisticModule when out_keys don't match the spec structure (helps catch misconfigured CompositeDistribution setups)
- Update tests to pass action_spec explicitly for composite action distributions
- Improve docstrings for the target_entropy parameter in SACLoss and REDQLoss

ghstack-source-id: f5a5fa5
Pull-Request: #3312
1 parent 8c65aa7 commit e4b6c92
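
For context, the "auto" heuristic needs a single action count even when the action spec is composite; the fix described above sums the numel of all leaf specs and then applies target_entropy = -n_actions. The snippet below is an illustrative sketch of that reduction using the public torchrl.data spec classes; the leaf_numel helper is hypothetical and is not the code added by this commit.

    # Illustrative sketch (not the patched TorchRL code): reduce a nested action
    # spec to a single action count by summing the numel of its leaf specs, then
    # apply the SAC "auto" heuristic target_entropy = -n_actions.
    import torch
    from torchrl.data import Bounded, Composite


    def leaf_numel(spec) -> int:
        # Hypothetical helper: recurse through Composite specs and count leaf elements.
        if isinstance(spec, Composite):
            return sum(leaf_numel(child) for child in spec.values())
        return spec.shape.numel()


    action_spec = Composite(
        agent0=Composite(action=Bounded(low=-1.0, high=1.0, shape=torch.Size([3]))),
        agent1=Composite(action=Bounded(low=-1.0, high=1.0, shape=torch.Size([2]))),
    )

    n_actions = leaf_numel(action_spec)  # 3 + 2 = 5
    target_entropy = -float(n_actions)   # "auto" heuristic: -dim(A) = -5.0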

3 files changed: +235 −21 lines changed

test/test_objectives.py

Lines changed: 193 additions & 14 deletions
@@ -4127,6 +4127,7 @@ def _create_mock_actor(
         observation_key="observation",
         action_key="action",
         composite_action_dist=False,
+        return_action_spec=False,
     ):
         # Actor
         action_spec = Bounded(
@@ -4161,7 +4162,10 @@ def _create_mock_actor(
             spec=action_spec,
         )
         assert actor.log_prob_keys
-        return actor.to(device)
+        actor = actor.to(device)
+        if return_action_spec:
+            return actor, action_spec
+        return actor

     def _create_mock_qvalue(
         self,
@@ -4419,9 +4423,19 @@ def test_sac(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4442,6 +4456,7 @@ def test_sac(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

@@ -4684,9 +4699,19 @@ def test_sac_state_dict(

         torch.manual_seed(self.seed)

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4707,6 +4732,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         sd = loss_fn.state_dict()
@@ -4716,6 +4742,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         loss_fn2.load_state_dict(sd)
@@ -4841,9 +4868,19 @@ def test_sac_batcher(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4864,6 +4901,7 @@ def test_sac_batcher(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

@@ -4998,7 +5036,16 @@ def test_sac_batcher(
     def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
         td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)

-        actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+            action_spec = None
         qvalue = self._create_mock_qvalue()
         if version == 1:
             value = self._create_mock_value()
@@ -5011,6 +5058,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
             value_network=value,
             num_qvalue_nets=2,
             loss_function="l2",
+            action_spec=action_spec,
         )

         default_keys = {
@@ -5266,6 +5314,27 @@ def test_sac_target_entropy_auto(self, version, action_dim):
             loss_fn.target_entropy.item() == -action_dim
         ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_sac_target_entropy_explicit(self, version, target_entropy):
+        """Regression test for explicit target_entropy values."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        if version == 1:
+            value = self._create_mock_value()
+        else:
+            value = None
+
+        loss_fn = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_reduction(self, reduction, version, composite_action_dist):
@@ -5278,9 +5347,19 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
         td = self._create_mock_data_sac(
             device=device, composite_action_dist=composite_action_dist
         )
-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -5295,6 +5374,7 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
             delay_actor=False,
             delay_value=False,
             reduction=reduction,
+            action_spec=action_spec,
         )
         loss_fn.make_value_estimator()
         loss = loss_fn(td)
@@ -5825,6 +5905,29 @@ def test_discrete_sac_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [2, 4, 8])
+    @pytest.mark.parametrize("target_entropy_weight", [0.5, 0.98])
+    def test_discrete_sac_target_entropy_auto(self, action_dim, target_entropy_weight):
+        """Regression test for target_entropy='auto' in DiscreteSACLoss."""
+        import numpy as np
+
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=action_dim,
+            target_entropy_weight=target_entropy_weight,
+            action_space="one-hot",
+        )
+        # target_entropy="auto" should compute -log(1/num_actions) * target_entropy_weight
+        expected = -float(np.log(1.0 / action_dim) * target_entropy_weight)
+        assert (
+            abs(loss_fn.target_entropy.item() - expected) < 1e-5
+        ), f"target_entropy should be {expected}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
@@ -6898,6 +7001,38 @@ def test_state_dict(
         )
         loss.load_state_dict(state)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_crossq_target_entropy_auto(self, action_dim):
+        """Regression test for target_entropy='auto' should be -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_crossq_target_entropy_explicit(self, target_entropy):
+        """Regression test for issue #3309: explicit target_entropy should work."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_crossq_reduction(self, reduction):
         torch.manual_seed(self.seed)
@@ -7301,6 +7436,22 @@ def test_redq_state_dict(self, delay_qvalue, num_qvalue, device):
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_redq_target_entropy_auto(self, action_dim):
+        """Regression test for target_entropy='auto' should be -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = REDQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("separate_losses", [False, True])
     def test_redq_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
@@ -8378,6 +8529,22 @@ def test_cql_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_cql_target_entropy_auto(self, action_dim):
+        """Regression test for target_entropy='auto' should be -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12390,6 +12557,18 @@ def test_odt_state_dict(self, device):
         loss_fn2 = OnlineDTLoss(actor)
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_odt_target_entropy_auto(self, action_dim):
+        """Regression test for target_entropy='auto' should be -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+
+        loss_fn = OnlineDTLoss(actor)
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_seq_odt(self, device):
         torch.manual_seed(self.seed)
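
Outside the test harness, the two behaviours the new SAC tests lock in can be reproduced with standard TorchRL components: target_entropy="auto" resolves to -dim(A), and an explicit value is stored verbatim. The sketch below is illustrative only; the network sizes, tensordict keys, and module choices are assumptions rather than code from this commit, and it omits the composite-distribution case, where the tests above additionally pass action_spec explicitly.

    # Minimal, self-contained sketch of the "auto" vs. explicit target_entropy paths.
    import torch
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torch import nn
    from torchrl.data import Bounded
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import SACLoss

    n_obs, n_act = 3, 4
    action_spec = Bounded(low=-1.0, high=1.0, shape=(n_act,))

    # Stochastic actor: observation -> (loc, scale) -> TanhNormal action
    actor = ProbabilisticActor(
        module=TensorDictModule(
            nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        spec=action_spec,
        distribution_class=TanhNormal,
    )


    class QValue(nn.Module):
        """Toy Q-network taking (observation, action) and returning a scalar."""

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(n_obs + n_act, 1)

        def forward(self, obs, act):
            return self.net(torch.cat([obs, act], dim=-1))


    qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])

    # "auto" default: target_entropy == -dim(A) == -4.0
    loss_auto = SACLoss(actor_network=actor, qvalue_network=qvalue)
    assert loss_auto.target_entropy.item() == -float(n_act)

    # An explicit value is used as given (the regression the new tests cover)
    loss_explicit = SACLoss(
        actor_network=actor, qvalue_network=qvalue, target_entropy=-2.0
    )
    assert loss_explicit.target_entropy.item() == -2.0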

torchrl/objectives/redq.py

Lines changed: 2 additions & 1 deletion
@@ -68,7 +68,8 @@ class REDQLoss(LossModule):
         fixed_alpha (bool, optional): whether alpha should be trained to match
             a target entropy. Default is ``False``.
         target_entropy (Union[str, Number], optional): Target entropy for the
-            stochastic policy. Default is "auto".
+            stochastic policy. Default is "auto", where target entropy is
+            computed as :obj:`-prod(n_actions)`.
         delay_qvalue (bool, optional): Whether to separate the target Q value
             networks from the Q value networks used
             for data collection. Default is ``False``.
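
A quick numeric illustration of the "-prod(n_actions)" default documented above (plain arithmetic, not TorchRL internals):

    # target_entropy="auto" => -prod(action_shape), i.e. minus the number of
    # action dimensions the stochastic policy emits.
    import numpy as np

    for action_shape in [(1,), (4,), (2, 3)]:
        print(action_shape, "->", -float(np.prod(action_shape)))
    # (1,)   -> -1.0
    # (4,)   -> -4.0
    # (2, 3) -> -6.0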
