4
4
# LICENSE file in the root directory of this source tree.
5
5
from __future__ import annotations
6
6
7
+ import copy
8
+
7
9
import warnings
8
10
from typing import Any , Callable , Iterator
9
11
@@ -55,6 +57,19 @@ class RayLLMCollector(LLMCollector):
55
57
or its subclass, responsible for updating the policy weights on remote inference workers.
56
58
ray_init_config (dict[str, Any], optional): keyword arguments to pass to ray.init().
57
59
remote_config (dict[str, Any], optional): keyword arguments to pass to cls.as_remote().
60
+ sync_iter (bool, optional): if `True`, items yielded by the collector will be synced to the local process.
61
+ If `False`, the collector will collect the next batch of data in between yielding.
62
+ This has no effect when data is collected through the :meth:`start` method.
63
+ For example:
64
+
65
+ >>> collector = RayLLMCollector(..., sync_iter=True)
66
+ >>> for data in collector: # blocking
67
+ ... # expensive operation - collector is idle
68
+ >>> collector = RayLLMCollector(..., sync_iter=False)
69
+ >>> for data in collector: # non-blocking
70
+ ... # expensive operation - collector is collecting data
71
+
72
+ Defaults to `True`.
58
73
verbose (bool, optional): if ``True``, the collector will print progress information.
59
74
Defaults to `False`.
60
75
"""
@@ -81,6 +96,7 @@ def __init__(
81
96
ray_init_config : dict [str , Any ] | None = None ,
82
97
remote_config : dict [str , Any ] | None = None ,
83
98
track_policy_version : bool | PolicyVersion = False ,
99
+ sync_iter : bool = True ,
84
100
verbose : bool = False ,
85
101
) -> None :
86
102
if not _has_ray :
@@ -93,8 +109,11 @@ def __init__(
93
109
94
110
ray_init_config = DEFAULT_RAY_INIT_CONFIG
95
111
ray .init (** ray_init_config )
96
-
112
+ if not sync_iter :
113
+ remote_config = copy .copy (remote_config )
114
+ remote_config .setdefault ("max_concurrency" , 2 )
97
115
remote_cls = LLMCollector .as_remote (remote_config ).remote
116
+ self .sync_iter = sync_iter
98
117
self ._collector = remote_cls (
99
118
env = env ,
100
119
policy = policy ,
@@ -113,19 +132,31 @@ def __init__(
113
132
verbose = verbose ,
114
133
)
115
134
135
+ def _next_remote (self ) -> None :
136
+ return self ._collector .next .remote ()
137
+
116
138
def next (self ) -> None :
117
139
"""Get the next batch of data from the collector.
118
140
119
141
Returns:
120
142
None as the data is written directly to the replay buffer.
121
143
"""
122
- return ray .get (self ._collector . next . remote ())
144
+ return ray .get (self ._next_remote ())
123
145
124
146
def __iter__ (self ) -> Iterator [None ]:
125
147
"""Returns an iterator that yields None as the collector writes directly to the replay buffer."""
148
+ if not self .sync_iter :
149
+ future = self ._next_remote ()
150
+ else :
151
+ future = None
126
152
while True :
127
153
try :
128
- yield self .next ()
154
+ if self .sync_iter :
155
+ yield self .next ()
156
+ else :
157
+ result = ray .get (future )
158
+ future = self ._next_remote ()
159
+ yield result
129
160
except StopIteration :
130
161
break
131
162
0 commit comments