|
27 | 27 | from torchrl._utils import RL_WARNINGS |
28 | 28 | from torchrl.envs.utils import step_mdp |
29 | 29 |
|
30 | | -from torchrl.objectives.utils import _vmap_func, hold_out_net |
| 30 | +from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST |
31 | 31 | from torchrl.objectives.value.functional import ( |
32 | 32 | generalized_advantage_estimate, |
33 | 33 | td0_return_estimate, |
@@ -78,6 +78,7 @@ def _call_value_nets( |
78 | 78 | single_call: bool, |
79 | 79 | value_key: NestedKey, |
80 | 80 | detach_next: bool, |
| 81 | + vmap_randomness: str = "error", |
81 | 82 | ): |
82 | 83 | in_keys = value_net.in_keys |
83 | 84 | if single_call: |
@@ -141,9 +142,11 @@ def _call_value_nets( |
141 | 142 | ) |
142 | 143 | elif params is not None: |
143 | 144 | params_stack = torch.stack([params, next_params], 0).contiguous() |
144 | | - data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack) |
| 145 | + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( |
| 146 | + data_in, params_stack |
| 147 | + ) |
145 | 148 | else: |
146 | | - data_out = vmap(value_net, (0,))(data_in) |
| 149 | + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) |
147 | 150 | value_est = data_out.get(value_key) |
148 | 151 | value, value_ = value_est[0], value_est[1] |
149 | 152 | data.set(value_key, value) |
@@ -214,6 +217,7 @@ class _AcceptedKeys: |
214 | 217 |
|
215 | 218 | default_keys = _AcceptedKeys() |
216 | 219 | value_network: Union[TensorDictModule, Callable] |
| 220 | + _vmap_randomness = None |
217 | 221 |
|
218 | 222 | @property |
219 | 223 | def advantage_key(self): |
@@ -428,6 +432,28 @@ def _next_value(self, tensordict, target_params, kwargs): |
428 | 432 | next_value = step_td.get(self.tensor_keys.value) |
429 | 433 | return next_value |
430 | 434 |
|
| 435 | + @property |
| 436 | + def vmap_randomness(self): |
| 437 | + if self._vmap_randomness is None: |
| 438 | + do_break = False |
| 439 | + for val in self.__dict__.values(): |
| 440 | + if isinstance(val, torch.nn.Module): |
| 441 | + for module in val.modules(): |
| 442 | + if isinstance(module, RANDOM_MODULE_LIST): |
| 443 | + self._vmap_randomness = "different" |
| 444 | + do_break = True |
| 445 | + break |
| 446 | + if do_break: |
| 447 | + # double break |
| 448 | + break |
| 449 | + else: |
| 450 | + self._vmap_randomness = "error" |
| 451 | + |
| 452 | + return self._vmap_randomness |
| 453 | + |
| 454 | + def set_vmap_randomness(self, value): |
| 455 | + self._vmap_randomness = value |
| 456 | + |
431 | 457 |
|
432 | 458 | class TD0Estimator(ValueEstimatorBase): |
433 | 459 | """Temporal Difference (TD(0)) estimate of advantage function. |
@@ -589,6 +615,7 @@ def forward( |
589 | 615 | single_call=self.shifted, |
590 | 616 | value_key=self.tensor_keys.value, |
591 | 617 | detach_next=True, |
| 618 | + vmap_randomness=self.vmap_randomness, |
592 | 619 | ) |
593 | 620 | else: |
594 | 621 | value = tensordict.get(self.tensor_keys.value) |
@@ -790,6 +817,7 @@ def forward( |
790 | 817 | single_call=self.shifted, |
791 | 818 | value_key=self.tensor_keys.value, |
792 | 819 | detach_next=True, |
| 820 | + vmap_randomness=self.vmap_randomness, |
793 | 821 | ) |
794 | 822 | else: |
795 | 823 | value = tensordict.get(self.tensor_keys.value) |
@@ -1001,6 +1029,7 @@ def forward( |
1001 | 1029 | single_call=self.shifted, |
1002 | 1030 | value_key=self.tensor_keys.value, |
1003 | 1031 | detach_next=True, |
| 1032 | + vmap_randomness=self.vmap_randomness, |
1004 | 1033 | ) |
1005 | 1034 | else: |
1006 | 1035 | value = tensordict.get(self.tensor_keys.value) |
@@ -1247,6 +1276,7 @@ def forward( |
1247 | 1276 | single_call=self.shifted, |
1248 | 1277 | value_key=self.tensor_keys.value, |
1249 | 1278 | detach_next=True, |
| 1279 | + vmap_randomness=self.vmap_randomness, |
1250 | 1280 | ) |
1251 | 1281 | else: |
1252 | 1282 | value = tensordict.get(self.tensor_keys.value) |
@@ -1329,6 +1359,7 @@ def value_estimate( |
1329 | 1359 | single_call=self.shifted, |
1330 | 1360 | value_key=self.tensor_keys.value, |
1331 | 1361 | detach_next=True, |
| 1362 | + vmap_randomness=self.vmap_randomness, |
1332 | 1363 | ) |
1333 | 1364 | else: |
1334 | 1365 | value = tensordict.get(self.tensor_keys.value) |
@@ -1575,6 +1606,7 @@ def forward( |
1575 | 1606 | single_call=self.shifted, |
1576 | 1607 | value_key=self.tensor_keys.value, |
1577 | 1608 | detach_next=True, |
| 1609 | + vmap_randomness=self.vmap_randomness, |
1578 | 1610 | ) |
1579 | 1611 | else: |
1580 | 1612 | value = tensordict.get(self.tensor_keys.value) |
|
0 commit comments