Skip to content

Commit 19015d5

Browse files
vmoens and cursoragent committed
[Feature] Auto-batching inference server: weight sync integration
Wires WeightSyncScheme into the server loop: - init_on_receiver + connect at startup - Non-blocking receive() poll between inference batches - threading.Lock protects model during weight updates - End-to-end tests and updated Sphinx docs with usage tutorial Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: 50480f5 Pull-Request: #3497
1 parent 80833ec commit 19015d5

File tree

3 files changed

+237
-1
lines changed

3 files changed

+237
-1
lines changed

docs/source/reference/modules_inference_server.rst

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,86 @@ The inference server provides auto-batching model serving for RL actors.
99
Multiple actors submit individual TensorDicts; the server transparently
1010
batches them, runs a single model forward pass, and routes results back.
1111

12+
Core API
13+
--------
14+
1215
.. autosummary::
1316
:toctree: generated/
1417
:template: rl_template_noinherit.rst
1518

1619
InferenceServer
1720
InferenceClient
1821
InferenceTransport
22+
23+
Transport Backends
24+
------------------
25+
26+
.. autosummary::
27+
:toctree: generated/
28+
:template: rl_template_noinherit.rst
29+
30+
ThreadingTransport
31+
MPTransport
32+
RayTransport
33+
MonarchTransport
34+
35+
Usage
36+
-----
37+
38+
The simplest setup uses :class:`ThreadingTransport` for actors that are
39+
threads in the same process:
40+
41+
.. code-block:: python
42+
43+
from tensordict.nn import TensorDictModule
44+
from torchrl.modules.inference_server import (
45+
InferenceServer,
46+
ThreadingTransport,
47+
)
48+
import torch.nn as nn
49+
import concurrent.futures
50+
51+
policy = TensorDictModule(
52+
nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
53+
in_keys=["observation"],
54+
out_keys=["action"],
55+
)
56+
57+
transport = ThreadingTransport()
58+
server = InferenceServer(policy, transport, max_batch_size=32)
59+
server.start()
60+
client = transport.client()
61+
62+
# actor threads call client(td) -- batched automatically
63+
with concurrent.futures.ThreadPoolExecutor(16) as pool:
64+
...
65+
66+
server.shutdown()
67+
68+
Weight Synchronisation
69+
^^^^^^^^^^^^^^^^^^^^^^
70+
71+
The server integrates with :class:`~torchrl.weight_update.WeightSyncScheme`
72+
to receive updated model weights from a trainer between inference batches:
73+
74+
.. code-block:: python
75+
76+
from torchrl.weight_update import SharedMemWeightSyncScheme
77+
78+
weight_sync = SharedMemWeightSyncScheme()
79+
# Initialise on the trainer (sender) side first
80+
weight_sync.init_on_sender(model=training_model, ...)
81+
82+
server = InferenceServer(
83+
model=inference_model,
84+
transport=ThreadingTransport(),
85+
weight_sync=weight_sync,
86+
)
87+
server.start()
88+
89+
# Training loop
90+
for batch in dataloader:
91+
loss = loss_fn(training_model(batch))
92+
loss.backward()
93+
optimizer.step()
94+
weight_sync.send(model=training_model) # pushed to server

test/test_inference_server.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,128 @@ def test_import_without_monarch(self):
557557
def test_instantiation_without_monarch_raises(self):
558558
with pytest.raises(ImportError, match="Monarch is required"):
559559
MonarchTransport()
560+
561+
562+
# =============================================================================
563+
# Tests: WeightSyncScheme integration (Commit 6)
564+
# =============================================================================
565+
566+
567+
class _SimpleWeightSync:
568+
"""Minimal mock that mimics the WeightSyncScheme receiver interface.
569+
570+
Stores a queue of weight TensorDicts. ``receive(timeout=...)`` pops
571+
the next one and applies it to the model via
572+
``TensorDict.from_module / to_module``.
573+
"""
574+
575+
def __init__(self):
576+
self._queue: list[TensorDictBase] = []
577+
self._model = None
578+
self.initialized_on_receiver = False
579+
self.synchronized_on_receiver = False
580+
581+
def init_on_receiver(self, *, model_id, model=None, worker_idx=0, **kwargs):
582+
self._model = model
583+
self.initialized_on_receiver = True
584+
585+
def connect(self, *, worker_idx=0):
586+
self.synchronized_on_receiver = True
587+
588+
def receive(self, timeout=None):
589+
if self._queue:
590+
weights = self._queue.pop(0)
591+
weights.to_module(self._model)
592+
return weights
593+
return None
594+
595+
def push(self, weights: TensorDictBase):
596+
"""Test helper: enqueue weights for the server to pick up."""
597+
self._queue.append(weights)
598+
599+
600+
class TestWeightSyncIntegration:
    """End-to-end checks for the server's WeightSyncScheme integration."""

    def test_weight_sync_init_called(self):
        """Server calls init_on_receiver and connect at startup."""
        import time

        tp = ThreadingTransport()
        model = _make_policy()
        scheme = _SimpleWeightSync()

        with InferenceServer(model, tp, weight_sync=scheme):
            # Let the worker thread spin up before inspecting the flags.
            time.sleep(0.1)
            assert scheme.initialized_on_receiver
            assert scheme.synchronized_on_receiver

    def test_weight_update_applied(self):
        """Weights pushed via weight_sync are applied to the model."""
        import time

        tp = ThreadingTransport()
        model = _make_policy()
        scheme = _SimpleWeightSync()

        with InferenceServer(model, tp, max_batch_size=4, weight_sync=scheme):
            client = tp.client()

            # Warm the pipeline with one prediction.
            td = TensorDict({"observation": torch.ones(4)})
            client(td)

            # Zero every parameter externally and push the update.
            zeroed = TensorDict.from_module(model)
            for key in zeroed.keys(True, True):
                zeroed[key] = torch.zeros_like(zeroed[key])
            scheme.push(zeroed)

            # Allow the server loop to pick up the new weights.
            time.sleep(0.2)

            # All-zero weights (bias included) must produce zero output.
            out = client(td)
            assert torch.allclose(out["action"], torch.zeros(2), atol=1e-6)

    def test_inference_continues_after_weight_update(self):
        """The server keeps serving after a weight update."""
        import time

        tp = ThreadingTransport()
        model = _make_policy()
        scheme = _SimpleWeightSync()

        with InferenceServer(model, tp, max_batch_size=4, weight_sync=scheme):
            client = tp.client()

            # A handful of requests before the update.
            for _ in range(5):
                out = client(TensorDict({"observation": torch.randn(4)}))
                assert "action" in out.keys()

            # Push a (no-op) weight update and give the loop time to apply it.
            scheme.push(TensorDict.from_module(model))
            time.sleep(0.1)

            # Service must continue uninterrupted afterwards.
            for _ in range(5):
                out = client(TensorDict({"observation": torch.randn(4)}))
                assert "action" in out.keys()
                assert out["action"].shape == (2,)

    def test_no_weight_sync(self):
        """Server works fine when weight_sync is None."""
        tp = ThreadingTransport()
        model = _make_policy()
        with InferenceServer(model, tp, max_batch_size=4):
            client = tp.client()
            out = client(TensorDict({"observation": torch.randn(4)}))
            assert "action" in out.keys()

torchrl/modules/inference_server/_server.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class InferenceServer:
4242
:class:`~torchrl.weight_update.WeightSyncScheme` used to receive
4343
updated model weights from a trainer. When set, the server polls
4444
for new weights between inference batches.
45+
weight_sync_model_id (str, optional): the model identifier used when
46+
initialising the weight sync scheme on the receiver side.
47+
Default: ``"policy"``.
4548
4649
Example:
4750
>>> from tensordict.nn import TensorDictModule
@@ -71,6 +74,7 @@ def __init__(
7174
collate_fn: Callable | None = None,
7275
device: torch.device | str | None = None,
7376
weight_sync=None,
77+
weight_sync_model_id: str = "policy",
7478
):
7579
self.model = model
7680
self.transport = transport
@@ -79,9 +83,12 @@ def __init__(
7983
self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
8084
self.device = torch.device(device) if device is not None else None
8185
self.weight_sync = weight_sync
86+
self._weight_sync_model_id = weight_sync_model_id
8287

8388
self._shutdown_event = threading.Event()
8489
self._worker: threading.Thread | None = None
90+
# Protects model access during weight updates
91+
self._model_lock = threading.Lock()
8592

8693
# -- lifecycle ------------------------------------------------------------
8794

@@ -119,9 +126,36 @@ def is_alive(self) -> bool:
119126

120127
# -- background loop ------------------------------------------------------
121128

129+
def _init_weight_sync(self) -> None:
130+
"""Initialise the weight sync scheme on the receiver (server) side."""
131+
ws = self.weight_sync
132+
if ws is None:
133+
return
134+
if not ws.initialized_on_receiver:
135+
ws.init_on_receiver(
136+
model_id=self._weight_sync_model_id,
137+
model=self.model,
138+
worker_idx=0,
139+
)
140+
if not ws.synchronized_on_receiver:
141+
ws.connect(worker_idx=0)
142+
143+
def _poll_weight_update(self) -> None:
144+
"""Non-blocking check for fresh weights from the trainer."""
145+
ws = self.weight_sync
146+
if ws is None:
147+
return
148+
with self._model_lock:
149+
ws.receive(timeout=0.0)
150+
122151
@torch.no_grad()
123152
def _run(self) -> None:
153+
self._init_weight_sync()
154+
124155
while not self._shutdown_event.is_set():
156+
# Poll for weight updates between batches (non-blocking)
157+
self._poll_weight_update()
158+
125159
self.transport.wait_for_work(timeout=self.timeout)
126160

127161
items, callbacks = self.transport.drain(self.max_batch_size)
@@ -133,7 +167,8 @@ def _run(self) -> None:
133167
batch = batch.to(self.device)
134168

135169
try:
136-
results = self.model(batch).unbind(0)
170+
with self._model_lock:
171+
results = self.model(batch).unbind(0)
137172
if len(results) != len(callbacks):
138173
raise RuntimeError(
139174
f"Model returned {len(results)} results for a "

0 commit comments

Comments
 (0)