|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import math |
| 8 | +import warnings |
8 | 9 | from dataclasses import dataclass |
9 | 10 |
|
10 | 11 | import torch |
@@ -85,6 +86,7 @@ def __init__( |
85 | 86 | target_entropy: str | float = "auto", |
86 | 87 | samples_mc_entropy: int = 1, |
87 | 88 | reduction: str | None = None, |
| 89 | + scalar_output_mode: str | None = None, |
88 | 90 | ) -> None: |
89 | 91 | self._in_keys = None |
90 | 92 | self._out_keys = None |
@@ -158,6 +160,22 @@ def __init__( |
158 | 160 | self._set_in_keys() |
159 | 161 | self.reduction = reduction |
160 | 162 |
|
| 163 | + # Handle scalar_output_mode for reduction="none" |
| 164 | + if reduction == "none" and scalar_output_mode is None: |
| 165 | + warnings.warn( |
| 166 | + "OnlineDTLoss with reduction='none' cannot include scalar values (alpha, entropy) " |
| 167 | + "in the output TensorDict without changing their shape. These values will be " |
| 168 | + "excluded from the output. You can access alpha via `loss_module.alpha` and " |
| 169 | + "compute entropy from the actor distribution. " |
| 170 | + "To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. " |
| 171 | + "Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. " |
| 172 | + "This is a known limitation we're working on improving.", |
| 173 | + category=UserWarning, |
| 174 | + stacklevel=2, |
| 175 | + ) |
| 176 | + scalar_output_mode = "exclude" |
| 177 | + self.scalar_output_mode = scalar_output_mode |
| 178 | + |
161 | 179 | def _set_in_keys(self): |
162 | 180 | keys = self.actor_network.in_keys |
163 | 181 | keys = set(keys) |
@@ -230,15 +248,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: |
230 | 248 | "loss_log_likelihood": -log_likelihood, |
231 | 249 | "loss_entropy": -entropy_bonus, |
232 | 250 | "loss_alpha": loss_alpha, |
233 | | - "entropy": entropy.detach().mean(), |
234 | | - "alpha": self.alpha.detach(), |
235 | 251 | } |
236 | | - td_out = TensorDict(out, []) |
237 | | - td_out = td_out.named_apply( |
238 | | - lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) |
239 | | - if name.startswith("loss_") |
240 | | - else value, |
241 | | - ) |
| 252 | + |
| 253 | + # Handle batch_size and scalar values (alpha, entropy) based on reduction mode |
| 254 | + if self.reduction == "none": |
| 255 | + batch_size = tensordict.batch_size |
| 256 | + td_out = TensorDict(out, batch_size=batch_size) |
| 257 | + if self.scalar_output_mode == "non_tensor": |
| 258 | + td_out.set_non_tensor("alpha", self.alpha.detach()) |
| 259 | + td_out.set_non_tensor("entropy", entropy.detach().mean()) |
| 260 | + else: |
| 261 | + out["entropy"] = entropy.detach().mean() |
| 262 | + out["alpha"] = self.alpha.detach() |
| 263 | + td_out = TensorDict(out, []) |
| 264 | + td_out = td_out.named_apply( |
| 265 | + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) |
| 266 | + if name.startswith("loss_") |
| 267 | + else value, |
| 268 | + ) |
242 | 269 | self._clear_weakrefs( |
243 | 270 | tensordict, |
244 | 271 | td_out, |
|
0 commit comments