Skip to content

Commit ff2e265

Browse files
BY571vmoens
authored andcommitted
[BugFix] Vmap randomness for value estimator (#1942)
1 parent bc95cbb commit ff2e265

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from torchrl._utils import RL_WARNINGS
2828
from torchrl.envs.utils import step_mdp
2929

30-
from torchrl.objectives.utils import _vmap_func, hold_out_net
30+
from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
3131
from torchrl.objectives.value.functional import (
3232
generalized_advantage_estimate,
3333
td0_return_estimate,
@@ -78,6 +78,7 @@ def _call_value_nets(
7878
single_call: bool,
7979
value_key: NestedKey,
8080
detach_next: bool,
81+
vmap_randomness: str = "error",
8182
):
8283
in_keys = value_net.in_keys
8384
if single_call:
@@ -141,9 +142,11 @@ def _call_value_nets(
141142
)
142143
elif params is not None:
143144
params_stack = torch.stack([params, next_params], 0).contiguous()
144-
data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack)
145+
data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
146+
data_in, params_stack
147+
)
145148
else:
146-
data_out = vmap(value_net, (0,))(data_in)
149+
data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
147150
value_est = data_out.get(value_key)
148151
value, value_ = value_est[0], value_est[1]
149152
data.set(value_key, value)
@@ -214,6 +217,7 @@ class _AcceptedKeys:
214217

215218
default_keys = _AcceptedKeys()
216219
value_network: Union[TensorDictModule, Callable]
220+
_vmap_randomness = None
217221

218222
@property
219223
def advantage_key(self):
@@ -428,6 +432,28 @@ def _next_value(self, tensordict, target_params, kwargs):
428432
next_value = step_td.get(self.tensor_keys.value)
429433
return next_value
430434

435+
@property
436+
def vmap_randomness(self):
437+
if self._vmap_randomness is None:
438+
do_break = False
439+
for val in self.__dict__.values():
440+
if isinstance(val, torch.nn.Module):
441+
for module in val.modules():
442+
if isinstance(module, RANDOM_MODULE_LIST):
443+
self._vmap_randomness = "different"
444+
do_break = True
445+
break
446+
if do_break:
447+
# double break
448+
break
449+
else:
450+
self._vmap_randomness = "error"
451+
452+
return self._vmap_randomness
453+
454+
def set_vmap_randomness(self, value):
455+
self._vmap_randomness = value
456+
431457

432458
class TD0Estimator(ValueEstimatorBase):
433459
"""Temporal Difference (TD(0)) estimate of advantage function.
@@ -589,6 +615,7 @@ def forward(
589615
single_call=self.shifted,
590616
value_key=self.tensor_keys.value,
591617
detach_next=True,
618+
vmap_randomness=self.vmap_randomness,
592619
)
593620
else:
594621
value = tensordict.get(self.tensor_keys.value)
@@ -790,6 +817,7 @@ def forward(
790817
single_call=self.shifted,
791818
value_key=self.tensor_keys.value,
792819
detach_next=True,
820+
vmap_randomness=self.vmap_randomness,
793821
)
794822
else:
795823
value = tensordict.get(self.tensor_keys.value)
@@ -1001,6 +1029,7 @@ def forward(
10011029
single_call=self.shifted,
10021030
value_key=self.tensor_keys.value,
10031031
detach_next=True,
1032+
vmap_randomness=self.vmap_randomness,
10041033
)
10051034
else:
10061035
value = tensordict.get(self.tensor_keys.value)
@@ -1247,6 +1276,7 @@ def forward(
12471276
single_call=self.shifted,
12481277
value_key=self.tensor_keys.value,
12491278
detach_next=True,
1279+
vmap_randomness=self.vmap_randomness,
12501280
)
12511281
else:
12521282
value = tensordict.get(self.tensor_keys.value)
@@ -1329,6 +1359,7 @@ def value_estimate(
13291359
single_call=self.shifted,
13301360
value_key=self.tensor_keys.value,
13311361
detach_next=True,
1362+
vmap_randomness=self.vmap_randomness,
13321363
)
13331364
else:
13341365
value = tensordict.get(self.tensor_keys.value)
@@ -1575,6 +1606,7 @@ def forward(
15751606
single_call=self.shifted,
15761607
value_key=self.tensor_keys.value,
15771608
detach_next=True,
1609+
vmap_randomness=self.vmap_randomness,
15781610
)
15791611
else:
15801612
value = tensordict.get(self.tensor_keys.value)

0 commit comments

Comments
 (0)