@@ -4127,6 +4127,7 @@ def _create_mock_actor(
         observation_key="observation",
         action_key="action",
         composite_action_dist=False,
+        return_action_spec=False,
     ):
         # Actor
         action_spec = Bounded(
@@ -4161,7 +4162,10 @@ def _create_mock_actor(
             spec=action_spec,
         )
         assert actor.log_prob_keys
-        return actor.to(device)
+        actor = actor.to(device)
+        if return_action_spec:
+            return actor, action_spec
+        return actor

     def _create_mock_qvalue(
         self,
@@ -4419,9 +4423,19 @@ def test_sac(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4442,6 +4456,7 @@ def test_sac(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

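The composite-distribution branch above recurs throughout the tests below. Pulled out of the test harness, the pattern looks like the following sketch; the nested spec shape and keys are illustrative only, and `actor`/`qvalue` stand in for the mock networks built by the helpers above.

```python
# Illustrative sketch, not the test's exact values. Per the comment in the
# hunks above, ProbabilisticActor doesn't preserve a composite action spec,
# so the spec is handed to the loss explicitly; target_entropy="auto" can
# then still derive dim(A) from it.
from torchrl.data import Bounded, Composite
from torchrl.objectives import SACLoss

action_spec = Composite(
    action=Composite(action1=Bounded(low=-1.0, high=1.0, shape=(4,)))
)
loss_fn = SACLoss(
    actor_network=actor,      # assumed: composite-distribution mock actor
    qvalue_network=qvalue,    # assumed: matching mock Q-network
    action_spec=action_spec,  # explicit spec stands in for the actor's own
)
```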
@@ -4684,9 +4699,19 @@ def test_sac_state_dict(

         torch.manual_seed(self.seed)

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4707,6 +4732,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         sd = loss_fn.state_dict()
@@ -4716,6 +4742,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         loss_fn2.load_state_dict(sd)
@@ -4841,9 +4868,19 @@ def test_sac_batcher(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4864,6 +4901,7 @@ def test_sac_batcher(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

@@ -4998,7 +5036,16 @@ def test_sac_batcher(
     def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
         td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)

-        actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+            action_spec = None
         qvalue = self._create_mock_qvalue()
         if version == 1:
             value = self._create_mock_value()
@@ -5011,6 +5058,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
             value_network=value,
             num_qvalue_nets=2,
             loss_function="l2",
+            action_spec=action_spec,
         )

         default_keys = {
@@ -5266,6 +5314,27 @@ def test_sac_target_entropy_auto(self, version, action_dim):
             loss_fn.target_entropy.item() == -action_dim
         ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_sac_target_entropy_explicit(self, version, target_entropy):
+        """Regression test for explicit target_entropy values."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        if version == 1:
+            value = self._create_mock_value()
+        else:
+            value = None
+
+        loss_fn = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_reduction(self, reduction, version, composite_action_dist):
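For context on the two SAC target-entropy tests above: with automatic entropy tuning, the temperature is trained to hold the policy's entropy near a target value, and the conventional heuristic for continuous actions (from the original SAC authors) sets that target to -dim(A); an explicit float simply replaces the heuristic. A minimal, torchrl-free arithmetic check of the convention:

```python
# The "auto" convention for continuous control: target entropy = -d for an
# action space of dimension d; an explicit float is stored verbatim instead.
def auto_target_entropy(action_dim: int) -> float:
    return -float(action_dim)

assert auto_target_entropy(4) == -4.0  # matches test_sac_target_entropy_auto
assert auto_target_entropy(1) == -1.0
```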
@@ -5278,9 +5347,19 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
         td = self._create_mock_data_sac(
             device=device, composite_action_dist=composite_action_dist
         )
-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -5295,6 +5374,7 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
             delay_actor=False,
             delay_value=False,
             reduction=reduction,
+            action_spec=action_spec,
         )
         loss_fn.make_value_estimator()
         loss = loss_fn(td)
@@ -5825,6 +5905,29 @@ def test_discrete_sac_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [2, 4, 8])
+    @pytest.mark.parametrize("target_entropy_weight", [0.5, 0.98])
+    def test_discrete_sac_target_entropy_auto(self, action_dim, target_entropy_weight):
+        """Regression test for target_entropy='auto' in DiscreteSACLoss."""
+        import numpy as np
+
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=action_dim,
+            target_entropy_weight=target_entropy_weight,
+            action_space="one-hot",
+        )
+        # target_entropy="auto" should compute -log(1/num_actions) * target_entropy_weight
+        expected = -float(np.log(1.0 / action_dim) * target_entropy_weight)
+        assert (
+            abs(loss_fn.target_entropy.item() - expected) < 1e-5
+        ), f"target_entropy should be {expected}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
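The discrete variant asserted above uses a different "auto" rule: the target is a fraction (`target_entropy_weight`) of the maximum entropy, i.e. the entropy of a uniform policy over the available actions. A worked numeric check with illustrative values:

```python
# Discrete "auto" target: -log(1/n) * weight = log(n) * weight, i.e. a
# fraction of the entropy of a uniform distribution over n actions.
import numpy as np

num_actions, weight = 4, 0.98
expected = -float(np.log(1.0 / num_actions) * weight)
print(f"{expected:.4f}")  # 1.3586 (= ln(4) * 0.98)
```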
@@ -6898,6 +7001,38 @@ def test_state_dict(
         )
         loss.load_state_dict(state)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_crossq_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should resolve to -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_crossq_target_entropy_explicit(self, target_entropy):
+        """Regression test for issue #3309: explicit target_entropy should work."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_crossq_reduction(self, reduction):
         torch.manual_seed(self.seed)
@@ -7301,6 +7436,22 @@ def test_redq_state_dict(self, delay_qvalue, num_qvalue, device):
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_redq_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = REDQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should resolve to -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("separate_losses", [False, True])
     def test_redq_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
@@ -8378,6 +8529,22 @@ def test_cql_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_cql_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should resolve to -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12390,6 +12557,18 @@ def test_odt_state_dict(self, device):
         loss_fn2 = OnlineDTLoss(actor)
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_odt_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+
+        loss_fn = OnlineDTLoss(actor)
+        # target_entropy="auto" should resolve to -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_seq_odt(self, device):
         torch.manual_seed(self.seed)