
Commit 3b86bdc

Update
[ghstack-poisoned]
1 parent ff7cff7 commit 3b86bdc

File tree

6 files changed: +525 −0 lines changed


docs/source/reference/modules.rst

Lines changed: 1 addition & 0 deletions
@@ -56,4 +56,5 @@ Documentation Sections
     modules_mcts
     modules_models
     modules_distributions
+    modules_inference_server
     modules_utils
docs/source/reference/modules_inference_server.rst

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
.. currentmodule:: torchrl.modules.inference_server

Inference Server
================

.. _ref_inference_server:

The inference server provides auto-batching model serving for RL actors.
Multiple actors submit individual TensorDicts; the server transparently
batches them, runs a single model forward pass, and routes results back.

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    InferenceServer
    InferenceClient
    InferenceTransport

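For orientation, here is a minimal end-to-end usage sketch of the API documented above, pieced together from the tests that follow. It assumes some concrete InferenceTransport is available; no production transport appears in this page's diff, so the in-process _MockTransport defined in the test file below stands in for one.

# Hypothetical usage sketch -- the concrete transport is an assumption;
# _MockTransport from the test file below is reused as a stand-in.
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.modules.inference_server import InferenceClient, InferenceServer
from test_inference_server import _MockTransport  # in-process stand-in transport

policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
transport = _MockTransport()

# The server drains pending requests, batches up to max_batch_size of them,
# runs a single forward pass, and resolves each caller's future with its row.
with InferenceServer(policy, transport, max_batch_size=16):
    client = InferenceClient(transport)
    result = client(TensorDict({"observation": torch.randn(4)}))  # blocking call
    print(result["action"].shape)  # torch.Size([2])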
test/test_inference_server.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import concurrent.futures
import threading

import pytest
import torch
import torch.nn as nn

from tensordict import lazy_stack, TensorDict
from tensordict.base import TensorDictBase
from tensordict.nn import TensorDictModule

from torchrl.modules.inference_server import (
    InferenceClient,
    InferenceServer,
    InferenceTransport,
)


# =============================================================================
# Helpers
# =============================================================================


class _MockTransport(InferenceTransport):
    """Minimal in-process transport for testing the core server logic."""

    def __init__(self):
        self._queue: list[TensorDictBase] = []
        self._futures: list[concurrent.futures.Future] = []
        self._lock = threading.Lock()
        self._event = threading.Event()

    def submit(self, td):
        fut = concurrent.futures.Future()
        with self._lock:
            self._queue.append(td)
            self._futures.append(fut)
        self._event.set()
        return fut

    def drain(self, max_items):
        with self._lock:
            n = min(len(self._queue), max_items)
            items = self._queue[:n]
            futs = self._futures[:n]
            del self._queue[:n]
            del self._futures[:n]
        return items, futs

    def wait_for_work(self, timeout):
        self._event.wait(timeout=timeout)
        self._event.clear()

    def resolve(self, callback, result):
        callback.set_result(result)

    def resolve_exception(self, callback, exc):
        callback.set_exception(exc)


def _make_policy():
    """A simple TensorDictModule for testing."""
    return TensorDictModule(
        nn.Linear(4, 2),
        in_keys=["observation"],
        out_keys=["action"],
    )


# =============================================================================
# Tests: core abstractions (Commit 1)
# =============================================================================


class TestInferenceTransportABC:
    def test_cannot_instantiate(self):
        with pytest.raises(TypeError):
            InferenceTransport()

    def test_client_returns_inference_client(self):
        transport = _MockTransport()
        client = transport.client()
        assert isinstance(client, InferenceClient)


class TestInferenceServerCore:
    def test_start_and_shutdown(self):
        transport = _MockTransport()
        policy = _make_policy()
        server = InferenceServer(policy, transport, max_batch_size=4)
        server.start()
        assert server.is_alive
        server.shutdown()
        assert not server.is_alive

    def test_context_manager(self):
        transport = _MockTransport()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4) as server:
            assert server.is_alive
        assert not server.is_alive

    def test_double_start_raises(self):
        transport = _MockTransport()
        policy = _make_policy()
        server = InferenceServer(policy, transport, max_batch_size=4)
        server.start()
        try:
            with pytest.raises(RuntimeError, match="already running"):
                server.start()
        finally:
            server.shutdown()

    def test_single_request(self):
        transport = _MockTransport()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            td = TensorDict({"observation": torch.randn(4)})
            fut = transport.submit(td)
            result = fut.result(timeout=5.0)
            assert "action" in result.keys()
            assert result["action"].shape == (2,)

    def test_batch_of_requests(self):
        transport = _MockTransport()
        policy = _make_policy()
        n = 8
        with InferenceServer(policy, transport, max_batch_size=16):
            futures = [
                transport.submit(TensorDict({"observation": torch.randn(4)}))
                for _ in range(n)
            ]
            results = [f.result(timeout=5.0) for f in futures]
            assert len(results) == n
            for r in results:
                assert "action" in r.keys()
                assert r["action"].shape == (2,)

    def test_collate_fn_is_called(self):
        calls = []

        def tracking_collate(items):
            calls.append(len(items))
            return lazy_stack(items)

        transport = _MockTransport()
        policy = _make_policy()
        with InferenceServer(
            policy, transport, max_batch_size=16, collate_fn=tracking_collate
        ):
            futures = [
                transport.submit(TensorDict({"observation": torch.randn(4)}))
                for _ in range(4)
            ]
            for f in futures:
                f.result(timeout=5.0)

        assert len(calls) >= 1
        assert sum(calls) == 4  # all 4 items processed

    def test_max_batch_size_respected(self):
        """The collate_fn should never receive more than max_batch_size items."""
        max_bs = 4
        seen_sizes = []

        def tracking_collate(items):
            seen_sizes.append(len(items))
            return lazy_stack(items)

        transport = _MockTransport()
        policy = _make_policy()
        # Submit many items then start the server
        n = 20
        futures = [
            transport.submit(TensorDict({"observation": torch.randn(4)}))
            for _ in range(n)
        ]
        with InferenceServer(
            policy,
            transport,
            max_batch_size=max_bs,
            collate_fn=tracking_collate,
        ):
            for f in futures:
                f.result(timeout=5.0)

        for s in seen_sizes:
            assert s <= max_bs


class TestInferenceClient:
    def test_sync_call(self):
        transport = _MockTransport()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            client = InferenceClient(transport)
            td = TensorDict({"observation": torch.randn(4)})
            result = client(td)
            assert "action" in result.keys()

    def test_submit_returns_future(self):
        transport = _MockTransport()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            client = InferenceClient(transport)
            td = TensorDict({"observation": torch.randn(4)})
            fut = client.submit(td)
            assert isinstance(fut, concurrent.futures.Future)
            result = fut.result(timeout=5.0)
            assert "action" in result.keys()
torchrl/modules/inference_server/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
from torchrl.modules.inference_server._transport import InferenceTransport

__all__ = [
    "InferenceClient",
    "InferenceServer",
    "InferenceTransport",
]
