Skip to content

Commit 509448b

Browse files
vmoens and cursoragent committed
[Feature] Auto-batching inference server: weight sync integration (#3497)
Wires WeightSyncScheme into the server loop: - init_on_receiver + connect at startup - Non-blocking receive() poll between inference batches - threading.Lock protects model during weight updates - End-to-end tests and updated Sphinx docs with usage tutorial Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: bbb55bd Pull-Request: #3497 Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 4083720 commit 509448b

File tree

3 files changed

+231
-1
lines changed

3 files changed

+231
-1
lines changed

docs/source/reference/modules_inference_server.rst

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,86 @@ The inference server provides auto-batching model serving for RL actors.
99
Multiple actors submit individual TensorDicts; the server transparently
1010
batches them, runs a single model forward pass, and routes results back.
1111

12+
Core API
13+
--------
14+
1215
.. autosummary::
1316
:toctree: generated/
1417
:template: rl_template_noinherit.rst
1518

1619
InferenceServer
1720
InferenceClient
1821
InferenceTransport
22+
23+
Transport Backends
24+
------------------
25+
26+
.. autosummary::
27+
:toctree: generated/
28+
:template: rl_template_noinherit.rst
29+
30+
ThreadingTransport
31+
MPTransport
32+
RayTransport
33+
MonarchTransport
34+
35+
Usage
36+
-----
37+
38+
The simplest setup uses :class:`ThreadingTransport` for actors that are
39+
threads in the same process:
40+
41+
.. code-block:: python
42+
43+
from tensordict.nn import TensorDictModule
44+
from torchrl.modules.inference_server import (
45+
InferenceServer,
46+
ThreadingTransport,
47+
)
48+
import torch.nn as nn
49+
import concurrent.futures
50+
51+
policy = TensorDictModule(
52+
nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
53+
in_keys=["observation"],
54+
out_keys=["action"],
55+
)
56+
57+
transport = ThreadingTransport()
58+
server = InferenceServer(policy, transport, max_batch_size=32)
59+
server.start()
60+
client = transport.client()
61+
62+
# actor threads call client(td) -- batched automatically
63+
with concurrent.futures.ThreadPoolExecutor(16) as pool:
64+
...
65+
66+
server.shutdown()
67+
68+
Weight Synchronisation
69+
^^^^^^^^^^^^^^^^^^^^^^
70+
71+
The server integrates with :class:`~torchrl.weight_update.WeightSyncScheme`
72+
to receive updated model weights from a trainer between inference batches:
73+
74+
.. code-block:: python
75+
76+
from torchrl.weight_update import SharedMemWeightSyncScheme
77+
78+
weight_sync = SharedMemWeightSyncScheme()
79+
# Initialise on the trainer (sender) side first
80+
weight_sync.init_on_sender(model=training_model, ...)
81+
82+
server = InferenceServer(
83+
model=inference_model,
84+
transport=ThreadingTransport(),
85+
weight_sync=weight_sync,
86+
)
87+
server.start()
88+
89+
# Training loop
90+
for batch in dataloader:
91+
loss = loss_fn(training_model(batch))
92+
loss.backward()
93+
optimizer.step()
94+
weight_sync.send(model=training_model) # pushed to server

test/test_inference_server.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import concurrent.futures
88
import threading
9+
import time
910

1011
import pytest
1112
import torch
@@ -557,3 +558,122 @@ def test_import_without_monarch(self):
557558
def test_instantiation_without_monarch_raises(self):
558559
with pytest.raises(ImportError, match="Monarch is required"):
559560
MonarchTransport()
561+
562+
563+
# =============================================================================
564+
# Tests: WeightSyncScheme integration (Commit 6)
565+
# =============================================================================
566+
567+
568+
class _SimpleWeightSync:
    """Minimal stand-in for the ``WeightSyncScheme`` receiver interface.

    Weight TensorDicts queued through :meth:`push` are handed out one at a
    time by ``receive(timeout=...)``, which also loads them into the model
    via ``TensorDict.from_module / to_module``.
    """

    def __init__(self):
        self._queue: "list[TensorDictBase]" = []
        self._model = None
        self.initialized_on_receiver = False
        self.synchronized_on_receiver = False

    def init_on_receiver(self, *, model_id, model=None, worker_idx=0, **kwargs):
        # Record the model that future weight updates will be loaded into.
        self._model = model
        self.initialized_on_receiver = True

    def connect(self, *, worker_idx=0):
        self.synchronized_on_receiver = True

    def receive(self, timeout=None):
        # Non-blocking poll: nothing queued -> report "no update" with None.
        if not self._queue:
            return None
        weights = self._queue.pop(0)
        weights.to_module(self._model)
        return weights

    def push(self, weights: "TensorDictBase"):
        """Test helper: enqueue weights for the server to pick up."""
        self._queue.append(weights)
600+
601+
class TestWeightSyncIntegration:
    """End-to-end checks for WeightSyncScheme integration in InferenceServer.

    The original versions of these tests used fixed ``time.sleep`` pauses to
    wait for the server's background loop; that is both slow on fast machines
    and flaky on loaded CI workers.  They now poll for the observable
    condition with a generous deadline instead.
    """

    @staticmethod
    def _wait_for(predicate, timeout=2.0, interval=0.01):
        # Poll until ``predicate()`` is truthy or ``timeout`` seconds elapse.
        # Returns the final predicate value so callers can assert on it.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(interval)
        return bool(predicate())

    def test_weight_sync_init_called(self):
        """Server calls init_on_receiver and connect at startup."""
        transport = ThreadingTransport()
        policy = _make_policy()
        ws = _SimpleWeightSync()

        with InferenceServer(policy, transport, weight_sync=ws):
            # Wait for the worker thread to run its startup path.
            assert self._wait_for(
                lambda: ws.initialized_on_receiver and ws.synchronized_on_receiver
            )

    def test_weight_update_applied(self):
        """Weights pushed via weight_sync are applied to the model."""
        transport = ThreadingTransport()
        policy = _make_policy()
        ws = _SimpleWeightSync()

        with InferenceServer(policy, transport, max_batch_size=4, weight_sync=ws):
            client = transport.client()

            # Prime the server with one request.
            td = TensorDict({"observation": torch.ones(4)})
            client(td)

            # Mutate the model weights externally and push via weight_sync.
            new_weights = TensorDict.from_module(policy)
            for key in new_weights.keys(True, True):
                new_weights[key] = torch.zeros_like(new_weights[key])
            ws.push(new_weights)

            # Poll (instead of a fixed sleep) until the server loop has
            # applied the update: with zero weights (bias included) the
            # linear head must output exactly zero.
            def _is_zeroed():
                out = client(TensorDict({"observation": torch.ones(4)}))
                return torch.allclose(out["action"], torch.zeros(2), atol=1e-6)

            assert self._wait_for(_is_zeroed)

    def test_inference_continues_after_weight_update(self):
        """The server keeps serving after a weight update."""
        transport = ThreadingTransport()
        policy = _make_policy()
        ws = _SimpleWeightSync()

        with InferenceServer(policy, transport, max_batch_size=4, weight_sync=ws):
            client = transport.client()

            # Initial requests.
            for _ in range(5):
                td = TensorDict({"observation": torch.randn(4)})
                result = client(td)
                assert "action" in result.keys()

            # Push a weight update and wait until the server consumed it.
            new_weights = TensorDict.from_module(policy)
            ws.push(new_weights)
            assert self._wait_for(lambda: not ws._queue)

            # Continue making requests.
            for _ in range(5):
                td = TensorDict({"observation": torch.randn(4)})
                result = client(td)
                assert "action" in result.keys()
                assert result["action"].shape == (2,)

    def test_no_weight_sync(self):
        """Server works fine when weight_sync is None."""
        transport = ThreadingTransport()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            client = transport.client()
            td = TensorDict({"observation": torch.randn(4)})
            result = client(td)
            assert "action" in result.keys()

torchrl/modules/inference_server/_server.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class InferenceServer:
4242
:class:`~torchrl.weight_update.WeightSyncScheme` used to receive
4343
updated model weights from a trainer. When set, the server polls
4444
for new weights between inference batches.
45+
weight_sync_model_id (str, optional): the model identifier used when
46+
initialising the weight sync scheme on the receiver side.
47+
Default: ``"policy"``.
4548
4649
Example:
4750
>>> from tensordict.nn import TensorDictModule
@@ -71,6 +74,7 @@ def __init__(
7174
collate_fn: Callable | None = None,
7275
device: torch.device | str | None = None,
7376
weight_sync=None,
77+
weight_sync_model_id: str = "policy",
7478
):
7579
self.model = model
7680
self.transport = transport
@@ -79,9 +83,12 @@ def __init__(
7983
self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
8084
self.device = torch.device(device) if device is not None else None
8185
self.weight_sync = weight_sync
86+
self._weight_sync_model_id = weight_sync_model_id
8287

8388
self._shutdown_event = threading.Event()
8489
self._worker: threading.Thread | None = None
90+
# Protects model access during weight updates
91+
self._model_lock = threading.Lock()
8592

8693
# -- lifecycle ------------------------------------------------------------
8794

@@ -119,10 +126,36 @@ def is_alive(self) -> bool:
119126

120127
# -- background loop ------------------------------------------------------
121128

129+
def _init_weight_sync(self) -> None:
130+
"""Initialise the weight sync scheme on the receiver (server) side."""
131+
ws = self.weight_sync
132+
if ws is None:
133+
return
134+
if not ws.initialized_on_receiver:
135+
ws.init_on_receiver(
136+
model_id=self._weight_sync_model_id,
137+
model=self.model,
138+
worker_idx=0,
139+
)
140+
if not ws.synchronized_on_receiver:
141+
ws.connect(worker_idx=0)
142+
143+
def _poll_weight_update(self) -> None:
144+
"""Non-blocking check for fresh weights from the trainer."""
145+
ws = self.weight_sync
146+
if ws is None:
147+
return
148+
with self._model_lock:
149+
ws.receive(timeout=0.0)
150+
122151
@torch.no_grad()
123152
def _run(self) -> None:
153+
self._init_weight_sync()
154+
124155
try:
125156
while not self._shutdown_event.is_set():
157+
self._poll_weight_update()
158+
126159
self.transport.wait_for_work(timeout=self.timeout)
127160

128161
items, callbacks = self.transport.drain(self.max_batch_size)
@@ -134,7 +167,8 @@ def _run(self) -> None:
134167
batch = batch.to(self.device)
135168

136169
try:
137-
results = self.model(batch).unbind(0)
170+
with self._model_lock:
171+
results = self.model(batch).unbind(0)
138172
if len(results) != len(callbacks):
139173
raise RuntimeError(
140174
f"Model returned {len(results)} results for a "

0 commit comments

Comments
 (0)