Skip to content

Commit 421dd56

Browse files
authored
Merge branch 'inclusionAI:main' into refactor/rpc_framework
2 parents db035ec + a90553b commit 421dd56

File tree

7 files changed

+554
-326
lines changed

7 files changed

+554
-326
lines changed

areal/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,17 +368,17 @@ class PPOActor:
368368
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
369369
self.config = config
370370
self.engine = engine
371+
self.temperature = config.temperature
371372

372373
@torch.no_grad()
373374
def compute_logp(
374375
self,
375376
data: dict[str, Any],
376-
temperature: float | None = None,
377377
) -> torch.Tensor | None:
378378

379379
def calc_logprobs(logits, input_data):
380380
labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
381-
logprobs = gather_logprobs(logits, labels, temperature or 1.0)
381+
logprobs = gather_logprobs(logits, labels, self.temperature)
382382
return logprobs
383383

384384
self.engine.eval()

areal/api/reward_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import os
23
import threading
34
import traceback
45
import weakref
@@ -12,6 +13,29 @@
1213
logger = logging.getLogger("Reward API")
1314

1415

16+
def _get_device_count_safely() -> int:
17+
"""
18+
Count accelerator device nodes in /dev (nvidia<N> or davinci<N>) without initializing a CUDA context; returns 8 if none are found.
19+
"""
20+
gpu_types = ["nvidia", "davinci"]
21+
try:
22+
if os.path.exists("/dev"):
23+
for gpu_type in gpu_types:
24+
devices = [
25+
f
26+
for f in os.listdir("/dev")
27+
if f.startswith(gpu_type) and f[len(gpu_type) :].isdigit()
28+
]
29+
if devices:
30+
return len(devices)
31+
except (OSError, ValueError) as e:
32+
# /dev could not be listed (e.g., permission denied); fall through to the default below
33+
logger.debug(f"Could not read device list from /dev, using fallback: {e}")
34+
35+
# Fallback: assume 8 devices; a larger device count yields a smaller, more cautious max_workers
36+
return 8
37+
38+
1539
def reward_fn(
1640
prompt: str,
1741
completions: str,
@@ -54,6 +78,12 @@ def __init__(
5478
):
5579
self.reward_fn = reward_fn
5680
self.timeout_seconds = timeout_seconds
81+
if max_workers is None:
82+
cpu_count = os.cpu_count() or 1
83+
device_count = _get_device_count_safely()
84+
# Heuristic for max_workers: distribute CPU cores across devices,
85+
# then halve to be conservative, ensuring at least one worker.
86+
max_workers = max((cpu_count // device_count) // 2, 1)
5787
self.max_workers = max_workers
5888
self.max_retries = max_retries
5989
self._executor_key = max_workers

areal/core/async_task_runner.py

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import asyncio
1212
import queue
13-
import random
1413
import threading
1514
import time
1615
from collections.abc import Awaitable, Callable
@@ -24,18 +23,26 @@
2423

2524
# Polling configuration
2625
DEFAULT_POLL_WAIT_TIME = 0.05 # 50ms
27-
DEFAULT_POLL_SLEEP_TIME = 0.5 # 1 second
26+
DEFAULT_POLL_SLEEP_TIME = 0.5 # 500ms
2827

2928

3029
class TaskQueueFullError(RuntimeError):
3130
"""Raised when an AsyncTaskRunner queue is full."""
3231

3332

3433
@dataclass
35-
class _TimedResult(Generic[T]):
36-
"""Internal wrapper for results with creation timestamp."""
34+
class TimedResult(Generic[T]):
35+
"""Wrapper for task results with creation timestamp.
3736
38-
create_time: int # nanoseconds from time.monotonic_ns()
37+
Attributes
38+
----------
39+
create_time : int
40+
Task creation time in nanoseconds from time.monotonic_ns().
41+
data : T
42+
The actual result data from the completed task.
43+
"""
44+
45+
create_time: int
3946
data: T
4047

4148

@@ -72,13 +79,13 @@ class AsyncTaskRunner(Generic[T]):
7279
----------
7380
max_queue_size : int
7481
Maximum size for input and output queues. Tasks submitted when
75-
the input queue is full will raise RuntimeError.
82+
the input queue is full will raise TaskQueueFullError.
7683
poll_wait_time : float, optional
7784
Time in seconds to wait for task completion during each poll
7885
cycle. Default is 0.05 (50ms).
7986
poll_sleep_time : float, optional
8087
Time in seconds to sleep between poll cycles.
81-
Default is 1.0 second.
88+
Default is 0.5 seconds.
8289
enable_tracing : bool, optional
8390
Enable detailed logging of task submission and completion.
8491
Default is False.
@@ -163,7 +170,7 @@ def __init__(
163170
Default is 0.05.
164171
poll_sleep_time : float, optional
165172
Time in seconds to sleep between poll cycles.
166-
Default is 1.0.
173+
Default is 0.5.
167174
enable_tracing : bool, optional
168175
Enable detailed logging. Default is False.
169176
"""
@@ -180,13 +187,10 @@ def __init__(
180187
self.input_queue: queue.Queue[_TaskInput[T]] = queue.Queue(
181188
maxsize=max_queue_size
182189
)
183-
self.output_queue: queue.Queue[_TimedResult[T]] = queue.Queue(
190+
self.output_queue: queue.Queue[TimedResult[T]] = queue.Queue(
184191
maxsize=max_queue_size
185192
)
186193

187-
# Cache for results to support wait() with arbitrary counts
188-
self.result_cache: list[_TimedResult[T]] = []
189-
190194
# Thread exception handling
191195
self._thread_exception_lock = threading.Lock()
192196
self._thread_exception: Exception | None = None
@@ -335,7 +339,7 @@ async def _run_async_loop(self):
335339
try:
336340
# Place result in output queue
337341
self.output_queue.put_nowait(
338-
_TimedResult(create_time=task_obj.create_time, data=result)
342+
TimedResult(create_time=task_obj.create_time, data=result)
339343
)
340344
if self.enable_tracing and self.logger:
341345
self.logger.info(
@@ -355,6 +359,7 @@ async def _run_async_loop(self):
355359
raise TaskQueueFullError(
356360
"Output queue full. Please increase max_queue_size."
357361
)
362+
# Sleep to avoid busy-waiting
358363
await asyncio.sleep(self.poll_sleep_time)
359364
finally:
360365
# Cancel all remaining tasks on shutdown
@@ -390,9 +395,10 @@ def submit(
390395
391396
Raises
392397
------
398+
TaskQueueFullError
399+
If the input queue is full.
393400
RuntimeError
394-
If the input queue is full or if the background thread
395-
has died.
401+
If the background thread has died.
396402
397403
Examples
398404
--------
@@ -417,12 +423,13 @@ def submit(
417423
"wait for tasks to complete."
418424
)
419425

420-
def wait(self, count: int, timeout: float | None = None) -> list[T]:
426+
def wait(
427+
self, count: int, timeout: float | None = None, with_timing: bool = False
428+
) -> list[TimedResult[T]] | list[T]:
421429
"""Wait for a specified number of task results.
422430
423431
This method blocks until at least `count` results are available
424-
or the timeout expires. Results are returned in random order
425-
(shuffled).
432+
or the timeout expires.
426433
427434
Parameters
428435
----------
@@ -431,11 +438,15 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
431438
timeout : float | None, optional
432439
Maximum time in seconds to wait. If None, waits indefinitely
433440
(up to 7 days). Default is None.
441+
with_timing : bool, optional
442+
If True, return TimedResult objects with creation timestamps.
443+
If False, return only the data values. Default is False.
434444
435445
Returns
436446
-------
437-
List[T]
438-
List of task results, shuffled randomly.
447+
list[TimedResult[T]] | list[T]
448+
If with_timing=True, returns list of TimedResult objects.
449+
If with_timing=False, returns list of result data.
439450
440451
Raises
441452
------
@@ -460,16 +471,7 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
460471
# Check thread health
461472
self._check_thread_health()
462473

463-
# Drain all available results from output queue
464-
while True:
465-
try:
466-
timed_result = self.output_queue.get_nowait()
467-
self.result_cache.append(timed_result)
468-
except queue.Empty:
469-
break
470-
471-
# Check if we have enough results
472-
if len(self.result_cache) >= count:
474+
if self.get_output_queue_size() >= count:
473475
break
474476

475477
# Sleep briefly to avoid busy waiting
@@ -480,23 +482,16 @@ def wait(self, count: int, timeout: float | None = None) -> list[T]:
480482
self._check_thread_health()
481483
raise RuntimeError("AsyncTaskRunner is exiting, cannot wait for results.")
482484

483-
accepted = len(self.result_cache)
485+
accepted = self.get_output_queue_size()
484486
if accepted < count:
485487
raise TimeoutError(
486488
f"Timed out waiting for {count} results, only received {accepted}."
487489
)
488490

489-
# Sort by creation time for deterministic ordering
490-
self.result_cache.sort(key=lambda x: x.create_time)
491-
492-
# Extract the requested number of results
493-
results_to_return = self.result_cache[:count]
494-
self.result_cache = self.result_cache[count:]
495-
496-
# Shuffle for randomness (helps with data diversity in ML)
497-
random.shuffle(results_to_return)
498-
499-
# Extract just the data (remove timing metadata)
491+
# Take the requested number of results in queue (FIFO) order, i.e. the order in which tasks finished
492+
results_to_return = [self.output_queue.get() for _ in range(count)]
493+
if with_timing:
494+
return results_to_return
500495
return [r.data for r in results_to_return]
501496

502497
def submit_batch(

0 commit comments

Comments
 (0)