Skip to content

Commit 8c25b5a

Browse files
峯回楚财
authored andcommitted
PullRequest: 945 rollout succuess
Merge branch test/tmp1102 of [email protected]:inclusionAI/AReaL.git into asystem/gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/945 Reviewed-by: 楚财 <[email protected]> * .
1 parent 72c8da4 commit 8c25b5a

File tree

4 files changed

+11
-2
lines changed

4 files changed

+11
-2
lines changed

areal/controller/rollout_controller.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ async def _wait_callback(self, worker: Worker):
301301
tik = time.time()
302302
while result == "NO_RESULT" and time.time() - tik < self.config.request_timeout:
303303
result = await self.scheduler.async_call_engine(
304-
worker.id, "wait_quiet", count=1, timeout=1, max_retries=1
304+
worker.id, "wait_quiet", count=1, timeout=3600, max_retries=1
305305
)
306306

307307
# The RPCServer will return None if the
@@ -402,6 +402,8 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:
402402
# Check capacity before submitting
403403
capacity = self.get_capacity()
404404
# Submit pending tasks
405+
self.logger.info(f"Capacity: {capacity}, pending inputs: {len(self._pending_inputs)}")
406+
405407
for _ in range(capacity):
406408
if len(self._pending_inputs) == 0:
407409
break
@@ -544,6 +546,7 @@ def prepare_batch(
544546
# Capacity exhausted during batch submission, stop and wait
545547
break
546548
try:
549+
self.logger.info("Wait for batch...")
547550
return self.wait(dataloader.batch_size, timeout=1)
548551
except TimeoutError:
549552
pass

areal/core/staleness_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def get_capacity(self, current_version: int) -> int:
9797

9898
# Return the minimum of both constraints
9999
capacity = min(concurrency_capacity, staleness_capacity)
100+
print(f"Capacity: {capacity}, max_concurrent_rollouts: {max_concurrent_rollouts}, rollout_stat: {self.rollout_stat}, consumer_bs: {consumer_bs}")
100101
return capacity
101102

102103
def on_rollout_submitted(self) -> None:

areal/examples/configs/my001/on_policy.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ rollout:
5454
model_path: ${tokenizer_path}
5555
storage_path: "${storage_prefix}/checkpoints"
5656
seed: ${seed}
57+
max_concurrent_rollouts: 64
58+
queue_size: null
59+
consumer_batch_size: ${train_dataset.batch_size}
60+
max_head_offpolicyness: 2
5761
engine_config:
5862
attention_backend: "triton"
5963
disable_custom_all_reduce: true

areal/extension/asystem/ascheduler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,9 @@ async def async_call_engine(
227227
EngineCallError: If the method call fails.
228228
"""
229229
logger.info(f"Async calling '{method}' on worker {worker_id}")
230+
max_retries = kwargs.pop("max_retries", 3)
230231
return await self.rpc_client.async_call_engine(
231-
worker_id, method, 3, *args, **kwargs
232+
worker_id, method, max_retries, *args, **kwargs
232233
)
233234

234235
async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any:

0 commit comments

Comments
 (0)