Skip to content

Commit 6d6daa4

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent 122bc89 commit 6d6daa4

File tree

5 files changed

+319
-453
lines changed

5 files changed

+319
-453
lines changed

test/test_inference_server.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import concurrent.futures
88
import threading
9+
import time
910

1011
import pytest
1112
import torch
@@ -606,8 +607,6 @@ def test_weight_sync_init_called(self):
606607

607608
with InferenceServer(policy, transport, weight_sync=ws):
608609
# Give the worker thread a moment to start
609-
import time
610-
611610
time.sleep(0.1)
612611
assert ws.initialized_on_receiver
613612
assert ws.synchronized_on_receiver
@@ -634,8 +633,6 @@ def test_weight_update_applied(self):
634633
ws.push(new_weights)
635634

636635
# Give the server loop a chance to apply the update
637-
import time
638-
639636
time.sleep(0.2)
640637

641638
# Now inference should reflect zero weights
@@ -662,8 +659,6 @@ def test_inference_continues_after_weight_update(self):
662659
new_weights = TensorDict.from_module(policy)
663660
ws.push(new_weights)
664661

665-
import time
666-
667662
time.sleep(0.1)
668663

669664
# Continue making requests

torchrl/modules/inference_server/_monarch.py

Lines changed: 45 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -4,103 +4,50 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
import queue
87
import threading
9-
import time
10-
from typing import Any
118

12-
from tensordict.base import TensorDictBase
9+
from torchrl.modules.inference_server._queue_transport import (
10+
_QueueInferenceClient,
11+
QueueBasedTransport,
12+
)
1313

14-
from torchrl.modules.inference_server._transport import InferenceTransport
1514

16-
_SENTINEL = object()
15+
class _MonarchRequestQueue:
16+
"""Wrapper around ``MonarchQueue`` that signals a :class:`threading.Event` on put.
1717
18+
Also adapts the Monarch queue API (``get(block=False)``) to the standard
19+
``get_nowait()`` expected by :class:`QueueBasedTransport`.
20+
"""
1821

19-
class _MonarchFuture:
20-
"""Future-like object for Monarch transport results.
22+
def __init__(self, monarch_queue, has_work: threading.Event):
23+
self._queue = monarch_queue
24+
self._has_work = has_work
2125

22-
Args:
23-
client: the :class:`_MonarchInferenceClient` that created this future.
24-
req_id: the unique request identifier within that client.
25-
"""
26+
def put(self, item):
27+
self._queue.put(item)
28+
self._has_work.set()
2629

27-
def __init__(self, client: _MonarchInferenceClient, req_id: int):
28-
self._client = client
29-
self._req_id = req_id
30-
self._result: Any = _SENTINEL
30+
def get(self, timeout=None):
31+
return self._queue.get(timeout=timeout)
3132

32-
def done(self) -> bool:
33-
"""Return ``True`` if the result is available without blocking."""
34-
if self._result is not _SENTINEL:
35-
return True
36-
try:
37-
self._result = self._client._get_result(self._req_id, timeout=0)
38-
except queue.Empty:
39-
return False
40-
return True
41-
42-
def result(self, timeout: float | None = None) -> TensorDictBase:
43-
"""Block until the result is available."""
44-
if self._result is _SENTINEL:
45-
self._result = self._client._get_result(self._req_id, timeout=timeout)
46-
if isinstance(self._result, BaseException):
47-
raise self._result
48-
return self._result
49-
50-
51-
class _MonarchInferenceClient:
52-
"""Actor-side client for :class:`MonarchTransport`.
53-
54-
Each client owns a dedicated response queue and routes results by
55-
request-id.
56-
57-
Args:
58-
request_queue: the shared Monarch queue for requests.
59-
response_queue: this client's dedicated response queue.
60-
actor_id: the unique identifier assigned by the transport.
61-
"""
33+
def get_nowait(self):
34+
return self._queue.get(block=False)
35+
36+
37+
class _MonarchResponseQueue:
38+
"""Thin wrapper adapting the MonarchQueue get API."""
39+
40+
def __init__(self, monarch_queue):
41+
self._queue = monarch_queue
6242

63-
def __init__(self, request_queue, response_queue, actor_id: int):
64-
self._request_queue = request_queue
65-
self._response_queue = response_queue
66-
self._actor_id = actor_id
67-
self._next_req_id = 0
68-
self._buffered: dict[int, Any] = {}
69-
70-
def __call__(self, td: TensorDictBase) -> TensorDictBase:
71-
"""Submit a request and block until the result is ready."""
72-
return self.submit(td).result()
73-
74-
def submit(self, td: TensorDictBase) -> _MonarchFuture:
75-
"""Submit a request and return a :class:`_MonarchFuture`."""
76-
req_id = self._next_req_id
77-
self._next_req_id += 1
78-
self._request_queue.put((self._actor_id, req_id, td))
79-
return _MonarchFuture(self, req_id)
80-
81-
# -- internal -------------------------------------------------------------
82-
83-
def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
84-
"""Return the result for *req_id*, buffering any earlier arrivals."""
85-
if req_id in self._buffered:
86-
return self._buffered.pop(req_id)
87-
deadline = None if timeout is None else time.monotonic() + timeout
88-
while True:
89-
remaining = None
90-
if deadline is not None:
91-
remaining = deadline - time.monotonic()
92-
if remaining <= 0:
93-
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
94-
try:
95-
rid, result = self._response_queue.get(timeout=remaining)
96-
except Exception:
97-
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
98-
if rid == req_id:
99-
return result
100-
self._buffered[rid] = result
101-
102-
103-
class MonarchTransport(InferenceTransport):
43+
def put(self, item):
44+
self._queue.put(item)
45+
46+
def get(self, timeout=None):
47+
return self._queue.get(timeout=timeout)
48+
49+
50+
class MonarchTransport(QueueBasedTransport):
10451
"""Transport using Monarch for distributed inference on GPU clusters.
10552
10653
Uses Monarch's actor model and RDMA-capable channels for efficient
@@ -118,6 +65,7 @@ class MonarchTransport(InferenceTransport):
11865
"""
11966

12067
def __init__(self, *, max_queue_size: int = 1000):
68+
super().__init__()
12169
try:
12270
import monarch # noqa: F401
12371
from monarch.tools.queue import MonarchQueue
@@ -126,66 +74,21 @@ def __init__(self, *, max_queue_size: int = 1000):
12674
"Monarch is required for MonarchTransport. "
12775
"Install it following the Monarch documentation."
12876
)
129-
self._request_queue = MonarchQueue(maxsize=max_queue_size)
130-
self._response_queues: dict[int, Any] = {}
131-
self._lock = threading.Lock()
132-
self._next_actor_id = 0
77+
self._has_work = threading.Event()
78+
self._request_queue = _MonarchRequestQueue(
79+
MonarchQueue(maxsize=max_queue_size), self._has_work
80+
)
81+
self._response_queues: dict[int, _MonarchResponseQueue] = {}
13382
self._MonarchQueue = MonarchQueue
13483

135-
# -- actor API ------------------------------------------------------------
84+
def _make_response_queue(self) -> _MonarchResponseQueue:
85+
return _MonarchResponseQueue(self._MonarchQueue(maxsize=1000))
13686

137-
def client(self) -> _MonarchInferenceClient:
87+
def client(self) -> _QueueInferenceClient:
13888
"""Create an actor-side client with a dedicated response queue.
13989
14090
Returns:
141-
A :class:`_MonarchInferenceClient` that can be passed to a Monarch
91+
A :class:`_QueueInferenceClient` that can be passed to a Monarch
14292
actor.
14393
"""
144-
with self._lock:
145-
actor_id = self._next_actor_id
146-
self._next_actor_id += 1
147-
response_queue = self._MonarchQueue(maxsize=1000)
148-
self._response_queues[actor_id] = response_queue
149-
return _MonarchInferenceClient(self._request_queue, response_queue, actor_id)
150-
151-
def submit(self, td: TensorDictBase):
152-
"""Not supported -- use :meth:`client` to obtain an actor handle."""
153-
raise RuntimeError(
154-
"MonarchTransport.submit() is not supported. "
155-
"Call transport.client() to create a _MonarchInferenceClient."
156-
)
157-
158-
# -- server API -----------------------------------------------------------
159-
160-
def drain(
161-
self, max_items: int
162-
) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
163-
"""Dequeue up to *max_items* pending requests (non-blocking)."""
164-
items: list[TensorDictBase] = []
165-
callbacks: list[tuple[int, int]] = []
166-
for _ in range(max_items):
167-
try:
168-
actor_id, req_id, td = self._request_queue.get(block=False)
169-
items.append(td)
170-
callbacks.append((actor_id, req_id))
171-
except Exception:
172-
break
173-
return items, callbacks
174-
175-
def wait_for_work(self, timeout: float) -> None:
176-
"""Block until at least one request is available or *timeout* elapses."""
177-
try:
178-
item = self._request_queue.get(timeout=timeout)
179-
self._request_queue.put(item)
180-
except Exception:
181-
pass
182-
183-
def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
184-
"""Route the result to the correct actor's response queue."""
185-
actor_id, req_id = callback
186-
self._response_queues[actor_id].put((req_id, result))
187-
188-
def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
189-
"""Route an exception to the correct actor's response queue."""
190-
actor_id, req_id = callback
191-
self._response_queues[actor_id].put((req_id, exc))
94+
return super().client()

0 commit comments

Comments
 (0)