@@ -68,15 +68,37 @@ class DataFlowConfig(BaseModel):
         int,
         Parameter(help="Target number of samples to collect before stopping."),
     ] = 8
-    enable_partial_rollout: Annotated[
-        int, Parameter(help="Whether to enable async rollout_controller. 1 for enabled, 0 for disabled")
-    ] = 0
     sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams()
     extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {}
+    # async params
+    staleness_threshold: Annotated[
+        float,
+        Parameter(
+            help="The maximum allowed fraction of stale (expired) samples in a training batch. Must be between 0.0 and 1.0."
+        ),
+    ] = 0.0
+    enable_partial_rollout: Annotated[
+        bool,
+        Parameter(help="Whether to enable partial rollout for asynchronous data generation."),
+    ] = False
+    tail_batch_candidate_steps: Annotated[
+        int,
+        Parameter(
+            help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable."
+        ),
+    ] = 0
+    tail_batch_trigger_size: Annotated[
+        Optional[int],
+        Parameter(
+            help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable; defaults to global_batch_size when unset."
+        ),
+    ] = None
     worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
 
     def model_post_init(self, __context: Any) -> None:
         self.worker_log_dir.mkdir(parents=True, exist_ok=True)
+        if self.tail_batch_trigger_size is None:
+            self.tail_batch_trigger_size = self.global_batch_size
 
 
 @ray.remote
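
As a reference for reviewers, a minimal usage sketch of the new async fields (field names come from the hunk above; it assumes global_batch_size is the config's existing batch-size field and that all other fields have defaults):

    # Hedged sketch -- illustrative values, not defaults from this change.
    cfg = DataFlowConfig(
        global_batch_size=32,          # existing field, read by model_post_init
        staleness_threshold=0.2,       # tolerate up to 20% stale samples per batch
        enable_partial_rollout=True,   # turn on asynchronous generation
        tail_batch_candidate_steps=4,  # samples become tail-batch candidates after 4 rollout steps
    )
    # tail_batch_trigger_size was left unset (None), so model_post_init
    # falls back to global_batch_size:
    assert cfg.tail_batch_trigger_size == 32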
@@ -106,17 +128,20 @@ def __init__(
             postprocessor (Optional[Callable]): An optional function to
                 post-process the generated samples.
         """
+        self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow")
         self.env = env
         self.config = dataflow_cfg
         replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
+        replay_buffer_cfg.enable_partial_rollout = self.config.enable_partial_rollout
+        replay_buffer_cfg.tail_batch_candidate_steps = self.config.tail_batch_candidate_steps
+        replay_buffer_cfg.tail_batch_trigger_size = self.config.tail_batch_trigger_size
         self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg)  # type: ignore[attr-defined]
         self.env_controller = environment
         self.finished_samples_count = 0
         self.skipped_sample_count = 0
         self.failed_sample_count = 0
-        self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
+        self.filtered_samples_count = 0
         self.target_batch_size = self.config.global_batch_size
-        self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
         rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
         self.worker_url_list = list(rollout_info["server_url_dict"].values())
         self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}")
@@ -140,28 +165,33 @@ def __init__(
             f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we proposed to set max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance."
         )
         self.enable_partial_rollout = self.config.enable_partial_rollout
+        self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
 
-    def _reset_internal_states(
+    def _prepare(
         self,
         global_batch_size: Optional[int] = None,
         sample_params: Optional[SampleParams] = None,
         extra_params: Optional[Dict] = None,
     ):
         """Resets all internal state variables of DataFlow."""
-        self.finished_samples_count = 0
-        self.skipped_sample_count = 0
-        self.failed_sample_count = 0
         if global_batch_size and global_batch_size > 0:
             self.target_batch_size = global_batch_size
         else:
             self.target_batch_size = self.config.global_batch_size
 
+        self.sample_from_expired_storage, self.finished_samples_count = ray.get(
+            self.replay_buffer.get_prerun_state.remote(self.target_batch_size)
+        )
+        self.skipped_sample_count = 0
+        self.failed_sample_count = 0
+        self.filtered_samples_count = 0
+
         self.sample_params = sample_params if sample_params else self.config.sample_params
         self.extra_params = extra_params if extra_params else self.config.extra_params
         logger_msg = (
-            f"DataFlow internal states reset for new run: target_batch_size={self.target_batch_size}, "
+            f"DataFlow state prepared for new generation: target_batch_size={self.target_batch_size}, "
             f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, "
-            f"enable_partial_rollout={self.enable_partial_rollout}. "
+            f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}"
         )
         self.logger.info(logger_msg)
 
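
A hedged reading of the pre-run handshake added above: get_prerun_state reports how many samples already sit finished in the replay buffer, so a new generation pass only tops the batch up instead of resetting the count to zero (values below are illustrative, not from this change):

    target = 32
    # Shape of the tuple returned by get_prerun_state, per the assignment above.
    sample_from_expired_storage, finished = True, 8
    remaining = target - finished  # only 24 fresh samples still needed this run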
@@ -191,7 +221,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
         # TODO(@duanyanhui): More fine-grained control over group data generation:
         # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
         group_data_items = await self.replay_buffer.sample.remote(  # type: ignore[attr-defined]
-            self.env, self.enable_partial_rollout, self.config.prompt_repeat_k
+            self.env, self.config.prompt_repeat_k
         )
         assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
         action_id = group_data_items[0].uid.action_id
@@ -207,6 +237,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
             group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
             if len(group_data_items) > 0:
                 await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+            else:
+                self.filtered_samples_count += 1
             self.logger.debug(f"Worker task completed successfully for {action_id}.")
         elif group_state == RolloutState.ABORTED:
             await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
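
The new else branch above counts groups that the post-processor dropped entirely. For intuition, a sketch of that kind of group-level filter; the real post_processor lives on the ReplayBuffer actor, and the reward attribute here is a hypothetical stand-in:

    def drop_degenerate_groups(group_data_items):
        # Hedged sketch: discard a group whose rewards carry no learning
        # signal (all identical). `reward` is assumed, not from this change.
        rewards = [item.reward for item in group_data_items]
        return group_data_items if len(set(rewards)) > 1 else []

Returning an empty list is exactly the case the new filtered_samples_count records.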
@@ -257,15 +289,15 @@ async def concurrent_task_runner(self):
             while len(waiting_tasks) < self.config.max_concurrent:
                 # In async mode, we keep spawning. In sync mode, we stop if we have enough tasks in flight.
                 if (
-                    not self.config.enable_partial_rollout
+                    not self.enable_partial_rollout
                     and self.finished_samples_count + len(waiting_tasks) >= self.target_batch_size
                 ):
                     break
                 task = create_task(self.worker_task())
                 waiting_tasks.add(task)
 
             _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
-            self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
+            self.finished_samples_count = ray.get(self.replay_buffer.get_completed_samples_count.remote())
             waiting_tasks = pending_tasks
 
             pbar.n = self.finished_samples_count
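
To make the sync-mode gate concrete, a tiny worked check of the spawn condition (names are local to this example):

    finished, in_flight, target = 30, 2, 32
    partial_rollout = False
    # Mirrors the break condition above: in sync mode, stop spawning once
    # finished plus in-flight work already covers the target batch.
    spawn_more = partial_rollout or finished + in_flight < target
    assert spawn_more is False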
@@ -340,7 +372,7 @@ async def run(
         Returns:
             List[RLDataFlowItem]: A list of collected training samples.
         """
-        self._reset_internal_states(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)
+        self._prepare(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)
 
         if resume:
             assert resume_path, "Resuming is enabled but no resume path is provided."
@@ -357,7 +389,13 @@ async def run(
         return await self.replay_buffer.get_samples.remote(self.target_batch_size)  # type: ignore[attr-defined]
 
     def logging_replaybuffer_state(self):
-        ray.get(self.replay_buffer.print.remote())
+        status = self.get_replaybuffer_status()
+        logging_msg = f"ReplayBuffer Status: {status}"
+        logging_msg += f", finished_samples_count: {self.finished_samples_count}, "
+        logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, "
+        logging_msg += f"failed_samples_count: {self.failed_sample_count}, "
+        logging_msg += f"filtered_samples_count: {self.filtered_samples_count}"
+        self.logger.info(logging_msg)
 
     def get_replaybuffer_status(self):
         return ray.get(self.replay_buffer.status.remote())
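
End to end, a hedged driver sketch of the touched entry points (actor construction is elided; the run signature follows the hunks above, and resume is assumed to default to False):

    # dataflow is a handle to the remote DataFlow actor (construction elided).
    samples = ray.get(dataflow.run.remote(num=64))  # List[RLDataFlowItem]
    ray.get(dataflow.logging_replaybuffer_state.remote())  # logs the new counters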