@@ -57,6 +57,11 @@
     clear_mpi_env_vars,
 )
 
+_CONSOLIDATE_ERR_CAPTURE = (
+    "TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
+    "`consolidate` keyword argument of the ParallelEnv constructor."
+)
+
 
 def _check_start(fun):
     def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
@@ -307,6 +312,7 @@ def __init__(
         non_blocking: bool = False,
         mp_start_method: str = None,
         use_buffers: bool = None,
+        consolidate: bool = True,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -315,6 +321,7 @@ def __init__(
         self.num_threads = num_threads
         self._cache_in_keys = None
         self._use_buffers = use_buffers
+        self.consolidate = consolidate
 
         self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
         if callable(create_env_fn):
@@ -841,9 +848,12 @@ def __repr__(self) -> str:
             f"\n\tbatch_size={self.batch_size})"
         )
 
-    def close(self) -> None:
+    def close(self, *, raise_if_closed: bool = True) -> None:
         if self.is_closed:
-            raise RuntimeError("trying to close a closed environment")
+            if raise_if_closed:
+                raise RuntimeError("trying to close a closed environment")
+            else:
+                return
         if self._verbose:
             torchrl_logger.info(f"closing {self.__class__.__name__}")
 
@@ -1470,6 +1480,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                         "_non_tensor_keys": self._non_tensor_keys,
                     }
                 )
+            else:
+                kwargs[idx].update(
+                    {
+                        "consolidate": self.consolidate,
+                    }
+                )
             process = proc_fun(target=func, kwargs=kwargs[idx])
             process.daemon = True
             process.start()
@@ -1526,7 +1542,16 @@ def _step_and_maybe_reset_no_buffers(
         else:
             workers_range = range(self.num_workers)
 
-        td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        if self.consolidate:
+            try:
+                td = tensordict.consolidate(
+                    share_memory=True, inplace=True, num_threads=1
+                )
+            except Exception as err:
+                raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+        else:
+            td = tensordict
+
         for i in workers_range:
             # We send the same td multiple times as it is in shared mem and we just need to index it
             # in each process.
@@ -1804,7 +1829,16 @@ def _step_no_buffers(
         else:
             workers_range = range(self.num_workers)
 
-        data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        if self.consolidate:
+            try:
+                data = tensordict.consolidate(
+                    share_memory=True, inplace=True, num_threads=1
+                )
+            except Exception as err:
+                raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+        else:
+            data = tensordict
+
         for i, local_data in zip(workers_range, data.unbind(0)):
             self.parent_channels[i].send(("step", local_data))
         # for i in range(data.shape[0]):
@@ -2026,9 +2060,14 @@ def _reset_no_buffers(
     ) -> Tuple[TensorDictBase, TensorDictBase]:
         if is_tensor_collection(tensordict):
             # tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
-            tensordict = tensordict.consolidate(
-                share_memory=True, num_threads=1
-            ).unbind(0)
+            if self.consolidate:
+                try:
+                    tensordict = tensordict.consolidate(
+                        share_memory=True, num_threads=1
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            tensordict = tensordict.unbind(0)
         else:
             tensordict = [None] * self.num_workers
         out_tds = [None] * self.num_workers
@@ -2545,6 +2584,7 @@ def _run_worker_pipe_direct(
     has_lazy_inputs: bool = False,
    verbose: bool = False,
    num_threads: int | None = None,  # for fork start method
+    consolidate: bool = True,
 ) -> None:
     if num_threads is not None:
         torch.set_num_threads(num_threads)
@@ -2634,9 +2674,18 @@ def _run_worker_pipe_direct(
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(
-                cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
-            )
+            if consolidate:
+                try:
+                    child_pipe.send(
+                        cur_td.consolidate(
+                            share_memory=True, inplace=True, num_threads=1
+                        )
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            else:
+                child_pipe.send(cur_td)
+
             del cur_td
 
         elif cmd == "step":
@@ -2650,9 +2699,18 @@ def _run_worker_pipe_direct(
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(
-                next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
-            )
+            if consolidate:
+                try:
+                    child_pipe.send(
+                        next_td.consolidate(
+                            share_memory=True, inplace=True, num_threads=1
+                        )
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            else:
+                child_pipe.send(next_td)
+
             del next_td
 
         elif cmd == "step_and_maybe_reset":
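For illustration, here is a minimal usage sketch of the two options this diff introduces. This is not part of the patch: it assumes a TorchRL build containing these changes, and the environment id "Pendulum-v1" (via gymnasium) is just an example choice.

    from torchrl.envs import GymEnv, ParallelEnv

    # consolidate=True (the default) keeps the TensorDict.consolidate fast path
    # for inter-process transfers; consolidate=False skips it, which is the
    # escape hatch the _CONSOLIDATE_ERR_CAPTURE message points users to when
    # consolidation fails.
    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"), consolidate=False)
    rollout = env.rollout(10)

    env.close()
    # Closing twice used to raise "trying to close a closed environment";
    # with raise_if_closed=False the second call is now a no-op.
    env.close(raise_if_closed=False)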