
Commit 4c6d2fc

[Feat][1/N] support async_rl in replaybuffer by supporting expired storage
1 parent 6cff996 commit 4c6d2fc

2 files changed: +412 −122 lines changed

xtuner/v1/ray/dataflow/flow.py

Lines changed: 54 additions & 16 deletions
@@ -68,15 +68,37 @@ class DataFlowConfig(BaseModel):
         int,
         Parameter(help="Target number of samples to collect before stopping."),
     ] = 8
-    enable_partial_rollout: Annotated[
-        int, Parameter(help="Whether to enable async rollout_controller. 1 for enabled, 0 for disabled")
-    ] = 0
     sample_params: Annotated[SampleParams, Parameter(help="Parameters for sampling from the model.")] = SampleParams()
     extra_params: Annotated[Dict, Parameter(help="Extra parameters for rollout.")] = {}
+    # async params
+    staleness_threshold: Annotated[
+        float,
+        Parameter(
+            help="The maximum allowed threshold of stale (expired) samples in a training batch. Must be between 0.0 and 1.0."
+        ),
+    ] = 0.0
+    enable_partial_rollout: Annotated[
+        bool,
+        Parameter(help="Whether to enable partial rollout for asynchronous data generation."),
+    ] = False
+    tail_batch_candidate_steps: Annotated[
+        int,
+        Parameter(
+            help="Number of rollout steps after which a sample becomes a candidate for the tail batch. Set to 0 to disable."
+        ),
+    ] = 0
+    tail_batch_trigger_size: Annotated[
+        Optional[int],
+        Parameter(
+            help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."
+        ),
+    ] = None
     worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"

     def model_post_init(self, __context: Any) -> None:
         self.worker_log_dir.mkdir(parents=True, exist_ok=True)
+        if self.tail_batch_trigger_size is None:
+            self.tail_batch_trigger_size = self.global_batch_size


 @ray.remote
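
The new config fields above are the knobs for the async behavior this commit introduces. Below is a minimal, hypothetical construction sketch: the field names and defaults come from the diff, but the import path is only inferred from the file location (xtuner/v1/ray/dataflow/flow.py) and the chosen values are made up.

# Hypothetical usage sketch; import path inferred, not confirmed by the repo's docs.
from xtuner.v1.ray.dataflow.flow import DataFlowConfig

cfg = DataFlowConfig(
    global_batch_size=64,
    enable_partial_rollout=True,    # switch the flow into async (partial rollout) mode
    staleness_threshold=0.25,       # allow up to 25% stale (expired) samples per training batch
    tail_batch_candidate_steps=2,   # samples become tail-batch candidates after 2 rollout steps
    tail_batch_trigger_size=None,   # None falls back to global_batch_size in model_post_init
)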
@@ -106,17 +128,20 @@ def __init__(
             postprocessor (Optional[Callable]): An optional function to
                 post-process the generated samples.
         """
+        self.logger = get_logger(log_dir=dataflow_cfg.worker_log_dir, tag="DataFlow")
         self.env = env
         self.config = dataflow_cfg
         replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
+        replay_buffer_cfg.enable_partial_rollout = self.config.enable_partial_rollout
+        replay_buffer_cfg.tail_batch_candidate_steps = self.config.tail_batch_candidate_steps
+        replay_buffer_cfg.tail_batch_trigger_size = self.config.tail_batch_trigger_size
         self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg)  # type: ignore[attr-defined]
         self.env_controller = environment
         self.finished_samples_count = 0
         self.skipped_sample_count = 0
         self.failed_sample_count = 0
-        self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
+        self.filtered_samples_count = 0
         self.target_batch_size = self.config.global_batch_size
-        self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
         rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
         self.worker_url_list = list(rollout_info["server_url_dict"].values())
         self.logger.info(f"DataFlow connected to active rollout workers url: {self.worker_url_list}")
@@ -140,28 +165,33 @@ def __init__(
                 f"Dataflow max_concurrent is set to {self.config.max_concurrent}, we proposed to set max_concurrent to {max_concurrent} based on rollout_max_batch_size_per_instance."
             )
         self.enable_partial_rollout = self.config.enable_partial_rollout
+        self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")

-    def _reset_internal_states(
+    def _prepare(
         self,
         global_batch_size: Optional[int] = None,
         sample_params: Optional[SampleParams] = None,
         extra_params: Optional[Dict] = None,
     ):
         """Resets all internal state variables of DataFlow."""
-        self.finished_samples_count = 0
-        self.skipped_sample_count = 0
-        self.failed_sample_count = 0
         if global_batch_size and global_batch_size > 0:
             self.target_batch_size = global_batch_size
         else:
             self.target_batch_size = self.config.global_batch_size

+        self.sample_from_expired_storage, self.finished_samples_count = ray.get(
+            self.replay_buffer.get_prerun_state.remote(self.target_batch_size)
+        )
+        self.skipped_sample_count = 0
+        self.failed_sample_count = 0
+        self.filtered_samples_count = 0
+
         self.sample_params = sample_params if sample_params else self.config.sample_params
         self.extra_params = extra_params if extra_params else self.config.extra_params
         logger_msg = (
-            f"DataFlow internal states reset for new run: target_batch_size={self.target_batch_size}, "
+            f"DataFlow states for new generations: target_batch_size={self.target_batch_size}, "
             f"sample_params: {self.sample_params}, extra_params: {self.extra_params}, "
-            f"enable_partial_rollout={self.enable_partial_rollout}."
+            f"sample_from_expired_storage={self.sample_from_expired_storage}, finished_samples_count={self.finished_samples_count}, "
         )
         self.logger.info(logger_msg)

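The reworked _prepare asks the replay buffer, via get_prerun_state, whether this run may start from expired (stale) storage and how many finished samples already carry over. The actor-side implementation is not part of this file's diff; the toy function below only illustrates one way such a decision could follow from staleness_threshold, with all names and arithmetic assumed rather than taken from the repo.

# Illustrative toy only: the real ReplayBuffer.get_prerun_state lives on the Ray actor
# and is not shown in this diff; names and math here are assumptions.
def get_prerun_state(finished_in_storage: int, target_batch_size: int,
                     staleness_threshold: float) -> tuple[bool, int]:
    # At most this many samples of the next batch may come from stale (expired) storage.
    stale_budget = int(target_batch_size * staleness_threshold)
    reusable = min(finished_in_storage, stale_budget)
    sample_from_expired_storage = reusable > 0
    return sample_from_expired_storage, reusable

print(get_prerun_state(finished_in_storage=10, target_batch_size=64, staleness_threshold=0.25))
# -> (True, 10): up to 16 stale samples allowed, 10 available, so 10 carry over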
@@ -191,7 +221,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
         # TODO(@duanyanhui): More fine-grained control over group data generation:
         # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
         group_data_items = await self.replay_buffer.sample.remote(  # type: ignore[attr-defined]
-            self.env, self.enable_partial_rollout, self.config.prompt_repeat_k
+            self.env, self.config.prompt_repeat_k
         )
         assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
         action_id = group_data_items[0].uid.action_id
@@ -207,6 +237,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
                 group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
                 if len(group_data_items) > 0:
                     await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+                else:
+                    self.filtered_samples_count += 1
                 self.logger.debug(f"Worker task completed successfully for {action_id}.")
             elif group_state == RolloutState.ABORTED:
                 await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
@@ -257,15 +289,15 @@ async def concurrent_task_runner(self):
             while len(waiting_tasks) < self.config.max_concurrent:
                 # In async mode, we keep spawning. In sync mode, we stop if we have enough tasks in flight.
                 if (
-                    not self.config.enable_partial_rollout
+                    not self.enable_partial_rollout
                     and self.finished_samples_count + len(waiting_tasks) >= self.target_batch_size
                 ):
                     break
                 task = create_task(self.worker_task())
                 waiting_tasks.add(task)

             _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
-            self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
+            self.finished_samples_count = ray.get(self.replay_buffer.get_completed_samples_count.remote())
             waiting_tasks = pending_tasks

             pbar.n = self.finished_samples_count
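
The loop above keeps up to max_concurrent worker tasks in flight and, in sync mode, stops spawning once finished plus in-flight tasks cover the target. The following self-contained asyncio sketch shows the same spawn/wait pattern in isolation; it uses no Ray and no replay buffer, and the fake worker and numbers are made up for the demo.

# Standalone illustration of the spawn/wait loop pattern (no Ray, no replay buffer).
import asyncio
import random

async def fake_worker(results: list) -> None:
    await asyncio.sleep(random.uniform(0.01, 0.05))  # stand-in for one rollout
    results.append(1)

async def runner(target: int, max_concurrent: int, partial_rollout: bool) -> int:
    results: list = []
    waiting: set = set()
    while len(results) < target:
        while len(waiting) < max_concurrent:
            # In sync mode, stop spawning once finished + in-flight covers the target.
            if not partial_rollout and len(results) + len(waiting) >= target:
                break
            waiting.add(asyncio.create_task(fake_worker(results)))
        _, pending = await asyncio.wait(waiting, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
        waiting = pending
    for task in waiting:  # drain any leftover in-flight tasks (async mode)
        task.cancel()
    return len(results)

print(asyncio.run(runner(target=8, max_concurrent=4, partial_rollout=False)))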
@@ -340,7 +372,7 @@ async def run(
         Returns:
             List[RLDataFlowItem]: A list of collected training samples.
         """
-        self._reset_internal_states(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)
+        self._prepare(global_batch_size=num, sample_params=sample_params, extra_params=extra_params)

         if resume:
             assert resume_path, "Resuming is enabled but no resume path is provided."
@@ -357,7 +389,13 @@ async def run(
         return await self.replay_buffer.get_samples.remote(self.target_batch_size)  # type: ignore[attr-defined]

     def logging_replaybuffer_state(self):
-        ray.get(self.replay_buffer.print.remote())
+        status = self.get_replaybuffer_status()
+        logging_msg = f"ReplayBuffer Status: {status}"
+        logging_msg += f", finished_samples_count: {self.finished_samples_count}, "
+        logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, "
+        logging_msg += f"failed_samples_count: {self.failed_sample_count}, "
+        logging_msg += f"filtered_samples_count: {self.filtered_samples_count}, "
+        self.logger.info(logging_msg)

     def get_replaybuffer_status(self):
         return ray.get(self.replay_buffer.status.remote())
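
Taken together, a driver loop might look roughly like the sketch below. The DataFlow constructor arguments are elided because the full __init__ signature is not shown in this diff, and the Ray actor call style is assumed rather than confirmed by the repo.

# Hedged driver sketch; constructor args are hypothetical placeholders.
import ray

dataflow = DataFlow.remote(env, dataflow_cfg, replay_buffer_cfg, environment)  # hypothetical args
samples = ray.get(dataflow.run.remote(num=64))          # run() now calls _prepare() internally
ray.get(dataflow.logging_replaybuffer_state.remote())   # logs finished/skipped/failed/filtered counts
status = ray.get(dataflow.get_replaybuffer_status.remote())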
