Skip to content

Commit 508aac2

Browse files
author
Juan de los Rios
committed
reduce complexity by requiring full path for entropy coeffs
1 parent 1eb0bee commit 508aac2

File tree

2 files changed

+12
-31
lines changed

2 files changed

+12
-31
lines changed

test/test_cost.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9779,11 +9779,11 @@ def mixture_constructor(logits, loc, scale):
97799779
# keep per-head entropies instead of the aggregated tensor
97809780
set_composite_lp_aggregate(False).set()
97819781
coef_map = {
9782-
("agent0", "action", "action1", "sub_action1_log_prob"): 0.10,
9783-
"sub_action2_log_prob": 0.10,
9784-
"action2": 0.10,
9785-
("agent1", "action_log_prob"): 0.10,
9786-
"agent2_log_prob": 0.02,
9782+
("agent0", "action", "action1", "sub_action1_log_prob"):0.02,
9783+
("agent0", "action", "action1", "sub_action2_log_prob"):0.01,
9784+
("agent0", "action", "action2_log_prob"):0.01,
9785+
("agent1", "action_log_prob"):0.01,
9786+
"agent2_log_prob":0.01,
97879787
}
97889788
ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
97899789
ppo_weighted.set_keys(
@@ -9872,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
98729872
torch.testing.assert_close(out, torch.tensor(-1.0))
98739873

98749874
def test_weighted_entropy_mapping(self):
9875-
coef = {"head_0": 0.3, "head_1": 0.7}
9875+
coef = {("head_0","action_log_prob"): 0.3, ("head_1","action_log_prob"): 0.7}
98769876
loss = self._make_entropy_loss(entropy_coeff=coef)
98779877
entropy = TensorDict(
98789878
{
@@ -9882,7 +9882,7 @@ def test_weighted_entropy_mapping(self):
98829882
[],
98839883
)
98849884
out = loss._weighted_loss_entropy(entropy)
9885-
expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
9885+
expected = -(coef[("head_0","action_log_prob")] * 1.0 + coef[("head_1","action_log_prob")] * 2.0)
98869886
torch.testing.assert_close(out, torch.tensor(expected))
98879887

98889888
def test_weighted_entropy_mapping_missing_key(self):

torchrl/objectives/ppo.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -923,33 +923,14 @@ def _weighted_loss_entropy(
923923
for head_name, entropy_head in entropy.items(
924924
include_nested=True, leaves_only=True
925925
):
926-
if isinstance(head_name, str):
927-
head_name = (head_name,)
928-
for i, (head_name_from_map, _coeff) in enumerate(
929-
self._entropy_coeff_map.items()
930-
):
931-
# Check if distinct head name inside tuple of nested dict
932-
if head_name_from_map in head_name:
933-
coeff = _coeff
934-
break
935-
# Check if path of head fully or partially in nested dict
936-
if any(
937-
head_name_from_map == head_name[i : i + len(head_name_from_map)]
938-
for i in range(len(head_name) - len(head_name_from_map) + 1)
939-
):
940-
coeff = _coeff
941-
break
942-
if i == len(self._entropy_coeff_map.items()):
943-
raise KeyError(
944-
f"Missing entropy coeff for head '{head_name}'"
945-
) from exec
926+
try:
927+
coeff = self._entropy_coeff_map[head_name]
928+
except KeyError as exc:
929+
raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
946930
coeff_t = torch.as_tensor(
947931
coeff, dtype=entropy_head.dtype, device=entropy_head.device
948932
)
949-
if isinstance(entropy_head, torch.Tensor):
950-
head_loss_term = -coeff_t * entropy_head
951-
else:
952-
head_loss_term = -coeff_t * _sum_td_features(entropy_head)
933+
head_loss_term = -coeff_t * entropy_head
953934
loss_term = (
954935
head_loss_term if loss_term is None else loss_term + head_loss_term
955936
) # accumulate

0 commit comments

Comments (0)