Skip to content

Commit 77c00b9

Browse files
authored
[Feature] Enable LineariseRewards to work with negative weights (#3064)
1 parent 523ba2e commit 77c00b9

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

test/test_transforms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13778,9 +13778,8 @@ def test_weight_shape_error(self):
1377813778
):
1377913779
LineariseRewards(in_keys=("reward",), weights=torch.ones(size=(2, 4)))
1378013780

13781-
def test_weight_sign_error(self):
13782-
with pytest.raises(ValueError, match="Expected all weights to be >0"):
13783-
LineariseRewards(in_keys=("reward",), weights=-torch.ones(size=(2,)))
13781+
def test_weight_no_sign_error(self):
13782+
LineariseRewards(in_keys=("reward",), weights=-torch.ones(size=(2,)))
1378413783

1378513784
def test_discrete_spec_error(self):
1378613785
with pytest.raises(
@@ -13980,6 +13979,7 @@ def _set_seed(self, seed: int | None = None) -> None:
1398013979
(1, None),
1398113980
(3, None),
1398213981
(2, [1.0, 2.0]),
13982+
(2, [1.0, -1.0]),
1398313983
],
1398413984
)
1398513985
def test_transform_env(self, num_rewards, weights):
@@ -14062,6 +14062,15 @@ def test_transform_inverse(self):
1406214062
),
1406314063
BoundedContinuous(low=-1.0, high=1.0, shape=1),
1406414064
),
14065+
(
14066+
[1.0, -1.0],
14067+
BoundedContinuous(
14068+
low=[-1.0, -2.0],
14069+
high=[1.0, 2.0],
14070+
shape=2,
14071+
),
14072+
BoundedContinuous(low=-3.0, high=3.0, shape=1),
14073+
),
1406514074
],
1406614075
)
1406714076
def test_reward_spec(

torchrl/envs/transforms/transforms.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10748,10 +10748,6 @@ def __init__(
1074810748
f"Expected weights to be a unidimensional tensor. Got {weights.ndim} dimension."
1074910749
)
1075010750

10751-
# Avoids switching from reward to costs.
10752-
if (weights < 0).any():
10753-
raise ValueError(f"Expected all weights to be >0. Got {weights}.")
10754-
1075510751
self.register_buffer("weights", weights)
1075610752
else:
1075710753
self.weights = None
@@ -10781,13 +10777,18 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
1078110777
reward_spec.shape = torch.Size([*batch_size, 1])
1078210778
return reward_spec
1078310779

10784-
# The lines below are correct only if all weights are positive.
10785-
low = (weights * reward_spec.space.low).sum(dim=-1, keepdim=True)
10786-
high = (weights * reward_spec.space.high).sum(dim=-1, keepdim=True)
10780+
weights_pos = weights.clamp(min=0)
10781+
weights_neg = weights.clamp(max=0)
10782+
10783+
low_pos = (weights_pos * reward_spec.space.low).sum(dim=-1, keepdim=True)
10784+
low_neg = (weights_neg * reward_spec.space.high).sum(dim=-1, keepdim=True)
10785+
10786+
high_pos = (weights_pos * reward_spec.space.high).sum(dim=-1, keepdim=True)
10787+
high_neg = (weights_neg * reward_spec.space.low).sum(dim=-1, keepdim=True)
1078710788

1078810789
return BoundedContinuous(
10789-
low=low,
10790-
high=high,
10790+
low=low_pos + low_neg,
10791+
high=high_pos + high_neg,
1079110792
device=reward_spec.device,
1079210793
dtype=reward_spec.dtype,
1079310794
)

0 commit comments

Comments
 (0)