@@ -5,6 +5,7 @@

from __future__ import annotations

+import asyncio
import warnings
from typing import Callable, Iterator, OrderedDict

@@ -21,6 +22,7 @@
    SyncDataCollector,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator

@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
            parameters being updated for a certain time even if ``update_after_each_batch``
            is turned on.
            Defaults to -1 (no forced update).
+        replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+
+            .. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
+                :class:`~torchrl.data.RayReplayBuffer` instance should be used here.

    Examples:
        >>> from torch import nn
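
The new ``replay_buffer`` argument turns the collector from a yielding iterator into a writer: batches are pushed into the buffer instead of being returned to the caller. Below is a minimal construction sketch, not taken from this diff; the environment id, the ``policy=None`` fallback to a random policy in the sub-collectors, and the default ``RayReplayBuffer`` construction are assumptions.

```python
# Sketch only: how the new replay_buffer argument might be wired up.
from torchrl.collectors.distributed import RayCollector
from torchrl.data import RayReplayBuffer
from torchrl.envs.libs.gym import GymEnv

buffer = RayReplayBuffer()  # assumed default construction

collector = RayCollector(
    [lambda: GymEnv("CartPole-v1")],  # list of env constructors (env id assumed)
    policy=None,  # assumed to fall back to a random policy in the sub-collectors
    frames_per_batch=200,
    total_frames=10_000,
    replay_buffer=buffer,  # data is written to the buffer instead of being yielded
)
```
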
@@ -312,7 +319,9 @@ def __init__(
        num_collectors: int = None,
        update_after_each_batch=False,
        max_weight_update_interval=-1,
+        replay_buffer: ReplayBuffer = None,
    ):
+        self.frames_per_batch = frames_per_batch
        if remote_configs is None:
            remote_configs = DEFAULT_REMOTE_CLASS_CONFIG

@@ -321,6 +330,14 @@ def __init__(

        if collector_kwargs is None:
            collector_kwargs = {}
+        if replay_buffer is not None:
+            if isinstance(collector_kwargs, dict):
+                collector_kwargs.setdefault("replay_buffer", replay_buffer)
+            else:
+                # setdefault mutates each kwargs dict in place; do not rebind
+                # collector_kwargs to the values it returns
+                for ck in collector_kwargs:
+                    ck.setdefault("replay_buffer", replay_buffer)

        # Make sure input parameters are consistent
        def check_consistency_with_num_collectors(param, param_name, num_collectors):
@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
            raise RuntimeError(
                "ray library not found, unable to create a DistributedCollector. "
            ) from RAY_ERR
-        ray.init(**ray_init_config)
+        if not ray.is_initialized():
+            ray.init(**ray_init_config)
        if not ray.is_initialized():
            raise RuntimeError("Ray could not be initialized.")

@@ -400,6 +418,7 @@ def check_list_length_consistency(*lists):
        collector_class.as_remote = as_remote
        collector_class.print_remote_collector_info = print_remote_collector_info

+        self.replay_buffer = replay_buffer
        self._local_policy = policy
        if isinstance(self._local_policy, nn.Module):
            policy_weights = TensorDict.from_module(self._local_policy)
@@ -557,7 +576,7 @@ def add_collectors(
                policy,
                other_params,
            )
-            self._remote_collectors.extend([collector])
+            self._remote_collectors.append(collector)

    def local_policy(self):
        """Returns local collector."""
@@ -577,17 +596,33 @@ def stop_remote_collectors(self):
            )  # This will interrupt any running tasks on the actor, causing them to fail immediately

    def iterator(self):
+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data
+
        if self._sync:
-            data = self._sync_iterator()
+            meth = self._sync_iterator
        else:
-            data = self._async_iterator()
+            meth = self._async_iterator
+        yield from (proc(data) for data in meth())

-        if self.split_trajs:
-            data = split_trajectories(data)
-        if self.postproc is not None:
-            data = self.postproc(data)
+    async def _asyncio_iterator(self):
+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data

-        return data
+        if self._sync:
+            for d in self._sync_iterator():
+                yield proc(d)
+        else:
+            for d in self._async_iterator():
+                yield proc(d)

    def _sync_iterator(self) -> Iterator[TensorDictBase]:
        """Collects one data batch per remote collector in each iteration."""
@@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
            ):
                self.update_policy_weights_(rank)

-        self.shutdown()
+        if self._task is None:
+            self.shutdown()
+
+    _task = None
+
+    def start(self):
+        """Starts the RayCollector."""
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for asyncio execution.")
+        if self._task is None or self._task.done():
+            loop = asyncio.get_event_loop()
+            self._task = loop.create_task(self._run_iterator_silently())
+
+    async def _run_iterator_silently(self):
+        async for _ in self._asyncio_iterator():
+            # Process each item silently
+            continue
+
+    async def async_shutdown(self):
+        """Finishes processes started by ray.init() during async execution."""
+        if self._task is not None:
+            await self._task
+        self.stop_remote_collectors()
+        ray.shutdown()

    def _async_iterator(self) -> Iterator[TensorDictBase]:
        """Collects a data batch from a single remote collector in each iteration."""
@@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
            ray.internal.free(
                [future]
            )  # should not be necessary, deleted automatically when ref count is down to 0
-            self.collected_frames += out_td.numel()
+            self.collected_frames += self.frames_per_batch

            yield out_td

@@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
689
747
# object_ref=ref,
690
748
# force=False,
691
749
# )
692
-
693
- self .shutdown ()
750
+ if self . _task is None :
751
+ self .shutdown ()
694
752
695
753
def update_policy_weights_ (self , worker_rank = None ) -> None :
696
754
"""Updates the weights of the worker nodes.