Skip to content

Commit 01413ca

Browse files
authored
[BugFix] Add scalar_output_mode to loss modules for reduction='none' (#3426)
1 parent 1b68722 commit 01413ca

File tree

8 files changed

+297
-34
lines changed

8 files changed

+297
-34
lines changed

.github/RELEASE_AGENT_PROMPT.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,31 @@ Get commits from the last release:
4242
git log v0.11.0..HEAD --oneline --no-merges
4343
```
4444

45+
**Important: PR Selection for Minor Releases**
46+
47+
When selecting PRs for a minor release, follow this decision flow:
48+
49+
1. **If labeled `user-facing`** → **Exclude** (only for major releases)
50+
2. **If labeled `non-user-facing` or `Suitable for minor`** → **Include**
51+
3. **If neither label is present** → **Assess yourself** based on the changes
52+
53+
Labels:
54+
- `user-facing` - API changes, new features, or public interface changes
55+
- `non-user-facing` - Internal changes, bug fixes, refactoring
56+
- `Suitable for minor` - Explicitly marked as safe for minor releases
57+
58+
To filter PRs:
59+
```bash
60+
# Find PRs explicitly safe for minor release
61+
gh pr list --label "non-user-facing" --state merged --json number,title
62+
gh pr list --label "Suitable for minor" --state merged --json number,title
63+
64+
# Check labels on a specific PR
65+
gh pr view <PR_NUMBER> --json labels --jq '.labels[].name'
66+
```
67+
68+
For unlabeled PRs, review the changes and determine if they affect the public API or just internal implementation.
69+
4570
### Critical: Don't Miss ghstack Commits
4671

4772
**The biggest pitfall in release notes is only looking at commits with PR numbers.** Many of the most significant features are merged via ghstack and have NO PR number in the commit message. Always analyze both:
@@ -477,9 +502,11 @@ After completing all steps, provide this summary to the user:
477502
## Version Naming Convention
478503

479504
- **Major releases**: `v0.11.0`, `v0.12.0` - New features, may have breaking changes
480-
- **Minor/Patch releases**: `v0.11.1`, `v0.11.2` - Bug fixes, no new features
505+
- **Minor/Patch releases**: `v0.11.1`, `v0.11.2` - Bug fixes only, no new features or user-facing changes
481506
- **Release candidates**: `v0.11.0-rc1` - Pre-release testing
482507

508+
**Note:** PRs labeled `user-facing` must only be included in major releases, never in minor/patch releases.
509+
483510
## TensorDict Version Compatibility
484511

485512
TorchRL and TensorDict versions must match in major version:

test/test_objectives.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5376,6 +5376,7 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
53765376
delay_value=False,
53775377
reduction=reduction,
53785378
action_spec=action_spec,
5379+
scalar_output_mode="exclude" if reduction == "none" else None,
53795380
)
53805381
loss_fn.make_value_estimator()
53815382
loss = loss_fn(td)
@@ -6259,6 +6260,7 @@ def test_discrete_sac_reduction(self, reduction):
62596260
action_space="one-hot",
62606261
delay_qvalue=False,
62616262
reduction=reduction,
6263+
scalar_output_mode="exclude" if reduction == "none" else None,
62626264
)
62636265
loss_fn.make_value_estimator()
62646266
loss = loss_fn(td)
@@ -7052,6 +7054,7 @@ def test_crossq_reduction(self, reduction):
70527054
qvalue_network=qvalue,
70537055
loss_function="l2",
70547056
reduction=reduction,
7057+
scalar_output_mode="exclude" if reduction == "none" else None,
70557058
)
70567059
loss_fn.make_value_estimator()
70577060
loss = loss_fn(td)
@@ -8043,6 +8046,7 @@ def test_redq_reduction(self, reduction, deprecated_loss):
80438046
loss_function="l2",
80448047
delay_qvalue=False,
80458048
reduction=reduction,
8049+
scalar_output_mode="exclude" if reduction == "none" else None,
80468050
)
80478051
loss_fn.make_value_estimator()
80488052
loss = loss_fn(td)
@@ -8706,6 +8710,7 @@ def test_cql_reduction(self, reduction):
87068710
delay_actor=False,
87078711
delay_qvalue=False,
87088712
reduction=reduction,
8713+
scalar_output_mode="exclude" if reduction == "none" else None,
87098714
)
87108715
loss_fn.make_value_estimator()
87118716
loss = loss_fn(td)
@@ -12677,7 +12682,11 @@ def test_onlinedt_reduction(self, reduction):
1267712682
)
1267812683
td = self._create_mock_data_odt(device=device)
1267912684
actor = self._create_mock_actor(device=device)
12680-
loss_fn = OnlineDTLoss(actor, reduction=reduction)
12685+
loss_fn = OnlineDTLoss(
12686+
actor,
12687+
reduction=reduction,
12688+
scalar_output_mode="exclude" if reduction == "none" else None,
12689+
)
1268112690
loss = loss_fn(td)
1268212691
if reduction == "none":
1268312692
for key in loss.keys():
@@ -13983,6 +13992,7 @@ def test_iql_reduction(self, reduction):
1398313992
value_network=value,
1398413993
loss_function="l2",
1398513994
reduction=reduction,
13995+
scalar_output_mode="exclude" if reduction == "none" else None,
1398613996
)
1398713997
loss_fn.make_value_estimator()
1398813998
with _check_td_steady(td), pytest.warns(
@@ -14815,6 +14825,7 @@ def test_discrete_iql_reduction(self, reduction):
1481514825
loss_function="l2",
1481614826
action_space="one-hot",
1481714827
reduction=reduction,
14828+
scalar_output_mode="exclude" if reduction == "none" else None,
1481814829
)
1481914830
loss_fn.make_value_estimator()
1482014831
with _check_td_steady(td), pytest.warns(

torchrl/objectives/cql.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def __init__(
298298
lagrange_thresh: float = 0.0,
299299
reduction: str | None = None,
300300
deactivate_vmap: bool = False,
301+
scalar_output_mode: str | None = None,
301302
) -> None:
302303
self._out_keys = None
303304
if reduction is None:
@@ -381,6 +382,23 @@ def __init__(
381382
)
382383
self._make_vmap()
383384
self.reduction = reduction
385+
386+
# Handle scalar_output_mode for reduction="none"
387+
if reduction == "none" and scalar_output_mode is None:
388+
warnings.warn(
389+
"CQLLoss with reduction='none' cannot include scalar values (alpha, entropy) "
390+
"in the output TensorDict without changing their shape. These values will be "
391+
"excluded from the output. You can access them via `loss_module._alpha` and "
392+
"compute entropy from the log_prob in the actor loss metadata. "
393+
"To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. "
394+
"Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. "
395+
"This is a known limitation we're working on improving.",
396+
category=UserWarning,
397+
stacklevel=2,
398+
)
399+
scalar_output_mode = "exclude"
400+
self.scalar_output_mode = scalar_output_mode
401+
384402
_ = self.target_entropy
385403

386404
def _make_vmap(self):
@@ -548,18 +566,28 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
548566
tensordict.set(
549567
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
550568
)
569+
entropy = -actor_metadata.get(self.tensor_keys.log_prob)
551570
out = {
552571
"loss_actor": loss_actor,
553572
"loss_actor_bc": loss_actor_bc,
554573
"loss_qvalue": q_loss,
555574
"loss_cql": cql_loss,
556575
"loss_alpha": loss_alpha,
557-
"alpha": self._alpha,
558-
"entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
559576
}
560577
if self.with_lagrange:
561578
out["loss_alpha_prime"] = alpha_prime_loss.mean()
562-
td_loss = TensorDict(out)
579+
580+
# Handle batch_size and scalar values (alpha, entropy) based on reduction mode
581+
if self.reduction == "none":
582+
batch_size = tensordict.batch_size
583+
td_loss = TensorDict(out, batch_size=batch_size)
584+
if self.scalar_output_mode == "non_tensor":
585+
td_loss.set_non_tensor("alpha", self._alpha)
586+
td_loss.set_non_tensor("entropy", entropy.detach().mean())
587+
else:
588+
out["alpha"] = self._alpha
589+
out["entropy"] = entropy.detach().mean()
590+
td_loss = TensorDict(out)
563591
self._clear_weakrefs(
564592
tensordict,
565593
td_loss,

torchrl/objectives/crossq.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import math
8+
import warnings
89
from dataclasses import dataclass
910
from functools import wraps
1011

@@ -274,6 +275,7 @@ def __init__(
274275
separate_losses: bool = False,
275276
reduction: str | None = None,
276277
deactivate_vmap: bool = False,
278+
scalar_output_mode: str | None = None,
277279
) -> None:
278280
self._in_keys = None
279281
self._out_keys = None
@@ -348,6 +350,23 @@ def __init__(
348350
self._action_spec = action_spec
349351
self._make_vmap()
350352
self.reduction = reduction
353+
354+
# Handle scalar_output_mode for reduction="none"
355+
if reduction == "none" and scalar_output_mode is None:
356+
warnings.warn(
357+
"CrossQLoss with reduction='none' cannot include scalar values (alpha, entropy) "
358+
"in the output TensorDict without changing their shape. These values will be "
359+
"excluded from the output. You can access them via `loss_module._alpha` and "
360+
"compute entropy from the log_prob in the actor loss metadata. "
361+
"To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. "
362+
"Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. "
363+
"This is a known limitation we're working on improving.",
364+
category=UserWarning,
365+
stacklevel=2,
366+
)
367+
scalar_output_mode = "exclude"
368+
self.scalar_output_mode = scalar_output_mode
369+
351370
# init target entropy
352371
self.maybe_init_target_entropy()
353372

@@ -553,12 +572,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
553572
"loss_actor": loss_actor,
554573
"loss_qvalue": loss_qvalue,
555574
"loss_alpha": loss_alpha,
556-
"alpha": self._alpha,
557-
"entropy": entropy.detach().mean(),
558575
**metadata_actor,
559576
**value_metadata,
560577
}
561-
td_out = TensorDict(out)
578+
579+
# Handle batch_size and scalar values (alpha, entropy) based on reduction mode
580+
if self.reduction == "none":
581+
batch_size = tensordict.batch_size
582+
td_out = TensorDict(out, batch_size=batch_size)
583+
if self.scalar_output_mode == "non_tensor":
584+
td_out.set_non_tensor("alpha", self._alpha)
585+
td_out.set_non_tensor("entropy", entropy.detach().mean())
586+
else:
587+
out["alpha"] = self._alpha
588+
out["entropy"] = entropy.detach().mean()
589+
td_out = TensorDict(out)
562590
self._clear_weakrefs(
563591
tensordict,
564592
td_out,

torchrl/objectives/decision_transformer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import math
8+
import warnings
89
from dataclasses import dataclass
910

1011
import torch
@@ -85,6 +86,7 @@ def __init__(
8586
target_entropy: str | float = "auto",
8687
samples_mc_entropy: int = 1,
8788
reduction: str | None = None,
89+
scalar_output_mode: str | None = None,
8890
) -> None:
8991
self._in_keys = None
9092
self._out_keys = None
@@ -158,6 +160,22 @@ def __init__(
158160
self._set_in_keys()
159161
self.reduction = reduction
160162

163+
# Handle scalar_output_mode for reduction="none"
164+
if reduction == "none" and scalar_output_mode is None:
165+
warnings.warn(
166+
"OnlineDTLoss with reduction='none' cannot include scalar values (alpha, entropy) "
167+
"in the output TensorDict without changing their shape. These values will be "
168+
"excluded from the output. You can access alpha via `loss_module.alpha` and "
169+
"compute entropy from the actor distribution. "
170+
"To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. "
171+
"Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. "
172+
"This is a known limitation we're working on improving.",
173+
category=UserWarning,
174+
stacklevel=2,
175+
)
176+
scalar_output_mode = "exclude"
177+
self.scalar_output_mode = scalar_output_mode
178+
161179
def _set_in_keys(self):
162180
keys = self.actor_network.in_keys
163181
keys = set(keys)
@@ -230,15 +248,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
230248
"loss_log_likelihood": -log_likelihood,
231249
"loss_entropy": -entropy_bonus,
232250
"loss_alpha": loss_alpha,
233-
"entropy": entropy.detach().mean(),
234-
"alpha": self.alpha.detach(),
235251
}
236-
td_out = TensorDict(out, [])
237-
td_out = td_out.named_apply(
238-
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
239-
if name.startswith("loss_")
240-
else value,
241-
)
252+
253+
# Handle batch_size and scalar values (alpha, entropy) based on reduction mode
254+
if self.reduction == "none":
255+
batch_size = tensordict.batch_size
256+
td_out = TensorDict(out, batch_size=batch_size)
257+
if self.scalar_output_mode == "non_tensor":
258+
td_out.set_non_tensor("alpha", self.alpha.detach())
259+
td_out.set_non_tensor("entropy", entropy.detach().mean())
260+
else:
261+
out["entropy"] = entropy.detach().mean()
262+
out["alpha"] = self.alpha.detach()
263+
td_out = TensorDict(out, [])
264+
td_out = td_out.named_apply(
265+
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
266+
if name.startswith("loss_")
267+
else value,
268+
)
242269
self._clear_weakrefs(
243270
tensordict,
244271
td_out,

0 commit comments

Comments
 (0)