@@ -374,6 +374,14 @@ def __init__(
374374
375375 is_spec_locked = EnvBase .is_spec_locked
376376
377+ def select_and_clone (self , name , tensor , selected_keys = None ):
378+ if selected_keys is None :
379+ selected_keys = self ._selected_step_keys
380+ if name in selected_keys :
381+ if self .device is not None and tensor .device != self .device :
382+ return tensor .to (self .device , non_blocking = self .non_blocking )
383+ return tensor .clone ()
384+
377385 @property
378386 def non_blocking (self ):
379387 nb = self ._non_blocking
@@ -1062,12 +1070,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10621070 selected_output_keys = self ._selected_reset_keys_filt
10631071
10641072 # select + clone creates 2 tds, but we can create one only
1065- def select_and_clone (name , tensor ):
1066- if name in selected_output_keys :
1067- return tensor .clone ()
1068-
10691073 out = self .shared_tensordict_parent .named_apply (
1070- select_and_clone ,
1074+ lambda * args : self .select_and_clone (
1075+ * args , selected_keys = selected_output_keys
1076+ ),
10711077 nested_keys = True ,
10721078 filter_empty = True ,
10731079 )
@@ -1135,14 +1141,14 @@ def _step(
11351141 # will be modified in-place at further steps
11361142 device = self .device
11371143
1138- def select_and_clone (name , tensor ):
1139- if name in self ._selected_step_keys :
1140- return tensor .clone ()
1144+ selected_keys = self ._selected_step_keys
11411145
11421146 if partial_steps is not None :
11431147 next_td = TensorDict .lazy_stack ([next_td [i ] for i in workers_range ])
11441148 out = next_td .named_apply (
1145- select_and_clone , nested_keys = True , filter_empty = True
1149+ lambda * args : self .select_and_clone (* args , selected_keys ),
1150+ nested_keys = True ,
1151+ filter_empty = True ,
11461152 )
11471153 if out_tds is not None :
11481154 out .update (
@@ -1841,20 +1847,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
18411847 next_td = shared_tensordict_parent .get ("next" )
18421848 device = self .device
18431849
1844- if next_td .device != device and device is not None :
1845-
1846- def select_and_clone (name , tensor ):
1847- if name in self ._selected_step_keys :
1848- return tensor .to (device , non_blocking = self .non_blocking )
1849-
1850- else :
1851-
1852- def select_and_clone (name , tensor ):
1853- if name in self ._selected_step_keys :
1854- return tensor .clone ()
1855-
18561850 out = next_td .named_apply (
1857- select_and_clone ,
1851+ self . select_and_clone ,
18581852 nested_keys = True ,
18591853 filter_empty = True ,
18601854 device = device ,
@@ -2005,20 +1999,10 @@ def tentative_update(val, other):
20051999 selected_output_keys = self ._selected_reset_keys_filt
20062000 device = self .device
20072001
2008- if self .shared_tensordict_parent .device != device and device is not None :
2009-
2010- def select_and_clone (name , tensor ):
2011- if name in selected_output_keys :
2012- return tensor .to (device , non_blocking = self .non_blocking )
2013-
2014- else :
2015-
2016- def select_and_clone (name , tensor ):
2017- if name in selected_output_keys :
2018- return tensor .clone ()
2019-
20202002 out = self .shared_tensordict_parent .named_apply (
2021- select_and_clone ,
2003+ lambda * args : self .select_and_clone (
2004+ * args , selected_keys = selected_output_keys
2005+ ),
20222006 nested_keys = True ,
20232007 filter_empty = True ,
20242008 device = device ,
0 commit comments