Commit 88ff619

committed
Update
[ghstack-poisoned]

2 parents 50f9bf4 + 6d6daa4 commit 88ff619

6 files changed: +458 -461 lines changed

test/test_inference_server.py

Lines changed: 119 additions & 6 deletions
@@ -6,6 +6,7 @@
 
 import concurrent.futures
 import threading
+import time
 
 import pytest
 import torch
@@ -21,6 +22,7 @@
     InferenceTransport,
     MPTransport,
     RayTransport,
+    SlotTransport,
     ThreadingTransport,
 )
 from torchrl.modules.inference_server._monarch import MonarchTransport
@@ -606,8 +608,6 @@ def test_weight_sync_init_called(self):
 
         with InferenceServer(policy, transport, weight_sync=ws):
             # Give the worker thread a moment to start
-            import time
-
             time.sleep(0.1)
             assert ws.initialized_on_receiver
             assert ws.synchronized_on_receiver
@@ -634,8 +634,6 @@ def test_weight_update_applied(self):
         ws.push(new_weights)
 
         # Give the server loop a chance to apply the update
-        import time
-
         time.sleep(0.2)
 
         # Now inference should reflect zero weights
@@ -662,8 +660,6 @@ def test_inference_continues_after_weight_update(self):
         new_weights = TensorDict.from_module(policy)
         ws.push(new_weights)
 
-        import time
-
         time.sleep(0.1)
 
         # Continue making requests
@@ -868,3 +864,120 @@ def postproc(td):
             pass
         collector.shutdown()
         assert called["count"] >= 1
+
+
+# =============================================================================
+# Tests: SlotTransport
+# =============================================================================
+
+
+class TestSlotTransport:
+    def test_single_request(self):
+        transport = SlotTransport(num_slots=4)
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            result = client(td)
+            assert "action" in result.keys()
+            assert result["action"].shape == (2,)
+
+    def test_concurrent_actors(self):
+        """Multiple threads submit concurrently via slot clients."""
+        n_actors = 4
+        n_requests = 30
+        transport = SlotTransport(num_slots=n_actors)
+        policy = _make_policy()
+
+        results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
+        clients = [transport.client() for _ in range(n_actors)]
+
+        def actor_fn(actor_id):
+            for _ in range(n_requests):
+                td = TensorDict({"observation": torch.randn(4)})
+                result = clients[actor_id](td)
+                results_per_actor[actor_id].append(result)
+
+        with InferenceServer(policy, transport, max_batch_size=n_actors):
+            with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
+                futs = [pool.submit(actor_fn, i) for i in range(n_actors)]
+                concurrent.futures.wait(futs)
+                for f in futs:
+                    f.result()
+
+        for actor_results in results_per_actor:
+            assert len(actor_results) == n_requests
+            for r in actor_results:
+                assert "action" in r.keys()
+                assert r["action"].shape == (2,)
+
+    def test_too_many_clients_raises(self):
+        """Creating more clients than slots raises RuntimeError."""
+        transport = SlotTransport(num_slots=2)
+        transport.client()
+        transport.client()
+        with pytest.raises(RuntimeError, match="slots"):
+            transport.client()
+
+    def test_submit_raises(self):
+        """Direct submit() on SlotTransport is not supported."""
+        transport = SlotTransport(num_slots=1)
+        td = TensorDict({"observation": torch.randn(4)})
+        with pytest.raises(NotImplementedError):
+            transport.submit(td)
+
+    def test_exception_propagates(self):
+        """Model exceptions propagate through SlotTransport."""
+
+        def bad_model(td):
+            raise ValueError("slot model error")
+
+        transport = SlotTransport(num_slots=1)
+        with InferenceServer(bad_model, transport, max_batch_size=4):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            with pytest.raises(ValueError, match="slot model error"):
+                client(td)
+
+
+# =============================================================================
+# Tests: min_batch_size
+# =============================================================================
+
+
+class TestMinBatchSize:
+    def test_min_batch_size_accumulates(self):
+        """With min_batch_size > 1, the server waits for enough items."""
+        min_bs = 4
+        seen_sizes = []
+
+        def tracking_collate(items):
+            seen_sizes.append(len(items))
+            return lazy_stack(items)
+
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        n = 8
+
+        with InferenceServer(
+            policy,
+            transport,
+            max_batch_size=16,
+            min_batch_size=min_bs,
+            collate_fn=tracking_collate,
+            timeout=1.0,
+        ):
+            client = transport.client()
+            # Submit items from threads to give the server time to accumulate
+            with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
+                futs = [
+                    pool.submit(
+                        lambda: client(TensorDict({"observation": torch.randn(4)}))
+                    )
+                    for _ in range(n)
+                ]
+                for f in futs:
+                    f.result(timeout=10.0)
+
+            # At least one batch should have >= min_batch_size items
+            assert any(s >= min_bs for s in seen_sizes)
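
Taken together, these tests pin down the SlotTransport contract: each concurrent actor must own its own client, the number of clients is capped by num_slots, and direct submit() is rejected. A minimal usage sketch of that contract (not part of the commit; it assumes SlotTransport is imported alongside the other transports as in the import hunk above, and reuses the test file's _make_policy helper):

    # Hedged sketch of the API the tests above exercise; _make_policy is the
    # helper from test/test_inference_server.py, everything else mirrors the
    # tests verbatim.
    transport = SlotTransport(num_slots=2)
    policy = _make_policy()
    with InferenceServer(policy, transport, max_batch_size=2):
        client_a = transport.client()  # slot 0
        client_b = transport.client()  # slot 1
        out = client_a(TensorDict({"observation": torch.randn(4)}))
        assert out["action"].shape == (2,)
        # A third transport.client() call raises RuntimeError (no free slots),
        # and transport.submit(td) raises NotImplementedError: all traffic
        # goes through per-slot clients.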

torchrl/modules/inference_server/_monarch.py

Lines changed: 45 additions & 142 deletions
@@ -4,103 +4,50 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-import queue
 import threading
-import time
-from typing import Any
 
-from tensordict.base import TensorDictBase
+from torchrl.modules.inference_server._queue_transport import (
+    _QueueInferenceClient,
+    QueueBasedTransport,
+)
 
-from torchrl.modules.inference_server._transport import InferenceTransport
 
-_SENTINEL = object()
+class _MonarchRequestQueue:
+    """Wrapper around ``MonarchQueue`` that signals a :class:`threading.Event` on put.
 
+    Also adapts the Monarch queue API (``get(block=False)``) to the standard
+    ``get_nowait()`` expected by :class:`QueueBasedTransport`.
+    """
 
-class _MonarchFuture:
-    """Future-like object for Monarch transport results.
+    def __init__(self, monarch_queue, has_work: threading.Event):
+        self._queue = monarch_queue
+        self._has_work = has_work
 
-    Args:
-        client: the :class:`_MonarchInferenceClient` that created this future.
-        req_id: the unique request identifier within that client.
-    """
+    def put(self, item):
+        self._queue.put(item)
+        self._has_work.set()
 
-    def __init__(self, client: _MonarchInferenceClient, req_id: int):
-        self._client = client
-        self._req_id = req_id
-        self._result: Any = _SENTINEL
+    def get(self, timeout=None):
+        return self._queue.get(timeout=timeout)
 
-    def done(self) -> bool:
-        """Return ``True`` if the result is available without blocking."""
-        if self._result is not _SENTINEL:
-            return True
-        try:
-            self._result = self._client._get_result(self._req_id, timeout=0)
-        except queue.Empty:
-            return False
-        return True
-
-    def result(self, timeout: float | None = None) -> TensorDictBase:
-        """Block until the result is available."""
-        if self._result is _SENTINEL:
-            self._result = self._client._get_result(self._req_id, timeout=timeout)
-        if isinstance(self._result, BaseException):
-            raise self._result
-        return self._result
-
-
-class _MonarchInferenceClient:
-    """Actor-side client for :class:`MonarchTransport`.
-
-    Each client owns a dedicated response queue and routes results by
-    request-id.
-
-    Args:
-        request_queue: the shared Monarch queue for requests.
-        response_queue: this client's dedicated response queue.
-        actor_id: the unique identifier assigned by the transport.
-    """
+    def get_nowait(self):
+        return self._queue.get(block=False)
+
+
+class _MonarchResponseQueue:
+    """Thin wrapper adapting the MonarchQueue get API."""
+
+    def __init__(self, monarch_queue):
+        self._queue = monarch_queue
 
-    def __init__(self, request_queue, response_queue, actor_id: int):
-        self._request_queue = request_queue
-        self._response_queue = response_queue
-        self._actor_id = actor_id
-        self._next_req_id = 0
-        self._buffered: dict[int, Any] = {}
-
-    def __call__(self, td: TensorDictBase) -> TensorDictBase:
-        """Submit a request and block until the result is ready."""
-        return self.submit(td).result()
-
-    def submit(self, td: TensorDictBase) -> _MonarchFuture:
-        """Submit a request and return a :class:`_MonarchFuture`."""
-        req_id = self._next_req_id
-        self._next_req_id += 1
-        self._request_queue.put((self._actor_id, req_id, td))
-        return _MonarchFuture(self, req_id)
-
-    # -- internal -------------------------------------------------------------
-
-    def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
-        """Return the result for *req_id*, buffering any earlier arrivals."""
-        if req_id in self._buffered:
-            return self._buffered.pop(req_id)
-        deadline = None if timeout is None else time.monotonic() + timeout
-        while True:
-            remaining = None
-            if deadline is not None:
-                remaining = deadline - time.monotonic()
-                if remaining <= 0:
-                    raise queue.Empty(f"Timeout waiting for result of request {req_id}")
-            try:
-                rid, result = self._response_queue.get(timeout=remaining)
-            except Exception:
-                raise queue.Empty(f"Timeout waiting for result of request {req_id}")
-            if rid == req_id:
-                return result
-            self._buffered[rid] = result
-
-
-class MonarchTransport(InferenceTransport):
+    def put(self, item):
+        self._queue.put(item)
+
+    def get(self, timeout=None):
+        return self._queue.get(timeout=timeout)
+
+
+class MonarchTransport(QueueBasedTransport):
     """Transport using Monarch for distributed inference on GPU clusters.
 
     Uses Monarch's actor model and RDMA-capable channels for efficient
@@ -118,6 +65,7 @@ class MonarchTransport(InferenceTransport):
     """
 
     def __init__(self, *, max_queue_size: int = 1000):
+        super().__init__()
         try:
            import monarch  # noqa: F401
            from monarch.tools.queue import MonarchQueue
@@ -126,66 +74,21 @@ def __init__(self, *, max_queue_size: int = 1000):
                 "Monarch is required for MonarchTransport. "
                 "Install it following the Monarch documentation."
             )
-        self._request_queue = MonarchQueue(maxsize=max_queue_size)
-        self._response_queues: dict[int, Any] = {}
-        self._lock = threading.Lock()
-        self._next_actor_id = 0
+        self._has_work = threading.Event()
+        self._request_queue = _MonarchRequestQueue(
+            MonarchQueue(maxsize=max_queue_size), self._has_work
+        )
+        self._response_queues: dict[int, _MonarchResponseQueue] = {}
         self._MonarchQueue = MonarchQueue
 
-    # -- actor API ------------------------------------------------------------
+    def _make_response_queue(self) -> _MonarchResponseQueue:
+        return _MonarchResponseQueue(self._MonarchQueue(maxsize=1000))
 
-    def client(self) -> _MonarchInferenceClient:
+    def client(self) -> _QueueInferenceClient:
         """Create an actor-side client with a dedicated response queue.
 
         Returns:
-            A :class:`_MonarchInferenceClient` that can be passed to a Monarch
+            A :class:`_QueueInferenceClient` that can be passed to a Monarch
             actor.
         """
-        with self._lock:
-            actor_id = self._next_actor_id
-            self._next_actor_id += 1
-            response_queue = self._MonarchQueue(maxsize=1000)
-            self._response_queues[actor_id] = response_queue
-        return _MonarchInferenceClient(self._request_queue, response_queue, actor_id)
-
-    def submit(self, td: TensorDictBase):
-        """Not supported -- use :meth:`client` to obtain an actor handle."""
-        raise RuntimeError(
-            "MonarchTransport.submit() is not supported. "
-            "Call transport.client() to create a _MonarchInferenceClient."
-        )
-
-    # -- server API -----------------------------------------------------------
-
-    def drain(
-        self, max_items: int
-    ) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
-        """Dequeue up to *max_items* pending requests (non-blocking)."""
-        items: list[TensorDictBase] = []
-        callbacks: list[tuple[int, int]] = []
-        for _ in range(max_items):
-            try:
-                actor_id, req_id, td = self._request_queue.get(block=False)
-                items.append(td)
-                callbacks.append((actor_id, req_id))
-            except Exception:
-                break
-        return items, callbacks
-
-    def wait_for_work(self, timeout: float) -> None:
-        """Block until at least one request is available or *timeout* elapses."""
-        try:
-            item = self._request_queue.get(timeout=timeout)
-            self._request_queue.put(item)
-        except Exception:
-            pass
-
-    def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
-        """Route the result to the correct actor's response queue."""
-        actor_id, req_id = callback
-        self._response_queues[actor_id].put((req_id, result))
-
-    def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
-        """Route an exception to the correct actor's response queue."""
-        actor_id, req_id = callback
-        self._response_queues[actor_id].put((req_id, exc))
+        return super().client()
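
The QueueBasedTransport base class that MonarchTransport now inherits from lives in torchrl/modules/inference_server/_queue_transport.py, which this commit page does not show. The sketch below is therefore an assumption about the contract the two wrappers above satisfy: the method names (drain, wait_for_work, resolve, resolve_exception) are taken from the Monarch-specific code removed here, but the bodies are illustrative, not the actual implementation.

    # Hypothetical sketch of a queue-generic transport base class. Request
    # queues are assumed to expose put()/get_nowait() and to set a shared
    # threading.Event on put(); response queues expose put(). None of this
    # is shown in the commit.
    import threading


    class _QueueBasedTransportSketch:
        def __init__(self):
            self._has_work = threading.Event()  # set by the request queue's put()
            self._request_queue = None          # installed by the subclass
            self._response_queues = {}

        def drain(self, max_items):
            """Dequeue up to max_items pending requests without blocking."""
            items, callbacks = [], []
            for _ in range(max_items):
                try:
                    actor_id, req_id, td = self._request_queue.get_nowait()
                except Exception:
                    break
                items.append(td)
                callbacks.append((actor_id, req_id))
            return items, callbacks

        def wait_for_work(self, timeout):
            """Block until a put() signals the event, then clear it."""
            if self._has_work.wait(timeout):
                self._has_work.clear()

        def resolve(self, callback, result):
            """Route a result (or, in resolve_exception, the exception
            itself) back to the submitting actor's response queue."""
            actor_id, req_id = callback
            self._response_queues[actor_id].put((req_id, result))

Whatever the real base class looks like, the Event-signalling _MonarchRequestQueue wrapper addresses a wart of the removed code: the old wait_for_work() popped an item off the Monarch queue and pushed it back, which (for a FIFO queue) moved the oldest pending request to the back; signalling an event on put() lets the server block without touching the queue at all.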
