diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index 6556c3b81b6..04555678ab0 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -56,4 +56,5 @@ Documentation Sections
     modules_mcts
     modules_models
     modules_distributions
+    modules_inference_server
     modules_utils
diff --git a/docs/source/reference/modules_inference_server.rst b/docs/source/reference/modules_inference_server.rst
new file mode 100644
index 00000000000..980eb4d0849
--- /dev/null
+++ b/docs/source/reference/modules_inference_server.rst
@@ -0,0 +1,18 @@
+.. currentmodule:: torchrl.modules.inference_server
+
+Inference Server
+================
+
+.. _ref_inference_server:
+
+The inference server provides auto-batching model serving for RL actors.
+Multiple actors submit individual TensorDicts; the server transparently
+batches them, runs a single model forward pass, and routes results back.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    InferenceServer
+    InferenceClient
+    InferenceTransport
diff --git a/test/test_inference_server.py b/test/test_inference_server.py
new file mode 100644
index 00000000000..9fe8dac757d
--- /dev/null
+++ b/test/test_inference_server.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import concurrent.futures
+import threading
+
+import pytest
+import torch
+import torch.nn as nn
+
+from tensordict import lazy_stack, TensorDict
+from tensordict.base import TensorDictBase
+from tensordict.nn import TensorDictModule
+
+from torchrl.modules.inference_server import (
+    InferenceClient,
+    InferenceServer,
+    InferenceTransport,
+)
+
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+class _MockTransport(InferenceTransport):
+    """Minimal in-process transport for testing the core server logic."""
+
+    def __init__(self):
+        self._queue: list[TensorDictBase] = []
+        self._futures: list[concurrent.futures.Future] = []
+        self._lock = threading.Lock()
+        self._event = threading.Event()
+
+    def submit(self, td):
+        fut = concurrent.futures.Future()
+        with self._lock:
+            self._queue.append(td)
+            self._futures.append(fut)
+            self._event.set()
+        return fut
+
+    def drain(self, max_items):
+        with self._lock:
+            n = min(len(self._queue), max_items)
+            items = self._queue[:n]
+            futs = self._futures[:n]
+            del self._queue[:n]
+            del self._futures[:n]
+        return items, futs
+
+    def wait_for_work(self, timeout):
+        self._event.wait(timeout=timeout)
+        self._event.clear()
+
+    def resolve(self, callback, result):
+        callback.set_result(result)
+
+    def resolve_exception(self, callback, exc):
+        callback.set_exception(exc)
+
+
+def _make_policy():
+    """A simple TensorDictModule for testing."""
+    return TensorDictModule(
+        nn.Linear(4, 2),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+
+# =============================================================================
+# Tests: core abstractions (Commit 1)
+# =============================================================================
+
+
+class TestInferenceTransportABC:
+    def test_cannot_instantiate(self):
+        with pytest.raises(TypeError):
+            InferenceTransport()
+
+    def test_client_returns_inference_client(self):
+        transport = _MockTransport()
+        client = transport.client()
+        assert isinstance(client, InferenceClient)
+
+
+class TestInferenceServerCore:
+    def test_start_and_shutdown(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        server = InferenceServer(policy, transport, max_batch_size=4)
+        server.start()
+        assert server.is_alive
+        server.shutdown()
+        assert not server.is_alive
+
+    def test_context_manager(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4) as server:
+            assert server.is_alive
+        assert not server.is_alive
+
+    def test_double_start_raises(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        server = InferenceServer(policy, transport, max_batch_size=4)
+        server.start()
+        try:
+            with pytest.raises(RuntimeError, match="already running"):
+                server.start()
+        finally:
+            server.shutdown()
+
+    def test_single_request(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            td = TensorDict({"observation": torch.randn(4)})
+            fut = transport.submit(td)
+            result = fut.result(timeout=5.0)
+            assert "action" in result.keys()
+            assert result["action"].shape == (2,)
+
+    def test_batch_of_requests(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        n = 8
+        with InferenceServer(policy, transport, max_batch_size=16):
+            futures = [
+                transport.submit(TensorDict({"observation": torch.randn(4)}))
+                for _ in range(n)
+            ]
+            results = [f.result(timeout=5.0) for f in futures]
+        assert len(results) == n
+        for r in results:
+            assert "action" in r.keys()
+            assert r["action"].shape == (2,)
+
+    def test_collate_fn_is_called(self):
+        calls = []
+
+        def tracking_collate(items):
+            calls.append(len(items))
+            return lazy_stack(items)
+
+        transport = _MockTransport()
+        policy = _make_policy()
+        with InferenceServer(
+            policy, transport, max_batch_size=16, collate_fn=tracking_collate
+        ):
+            futures = [
+                transport.submit(TensorDict({"observation": torch.randn(4)}))
+                for _ in range(4)
+            ]
+            for f in futures:
+                f.result(timeout=5.0)
+
+        assert len(calls) >= 1
+        assert sum(calls) == 4  # all 4 items processed
+
+    def test_max_batch_size_respected(self):
+        """The collate_fn should never receive more than max_batch_size items."""
+        max_bs = 4
+        seen_sizes = []
+
+        def tracking_collate(items):
+            seen_sizes.append(len(items))
+            return lazy_stack(items)
+
+        transport = _MockTransport()
+        policy = _make_policy()
+        # Submit many items then start the server
+        n = 20
+        futures = [
+            transport.submit(TensorDict({"observation": torch.randn(4)}))
+            for _ in range(n)
+        ]
+        with InferenceServer(
+            policy,
+            transport,
+            max_batch_size=max_bs,
+            collate_fn=tracking_collate,
+        ):
+            for f in futures:
+                f.result(timeout=5.0)
+
+        for s in seen_sizes:
+            assert s <= max_bs
+
+
+class TestInferenceClient:
+    def test_sync_call(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            client = InferenceClient(transport)
+            td = TensorDict({"observation": torch.randn(4)})
+            result = client(td)
+            assert "action" in result.keys()
+
+    def test_submit_returns_future(self):
+        transport = _MockTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            client = InferenceClient(transport)
+            td = TensorDict({"observation": torch.randn(4)})
+            fut = client.submit(td)
+            assert isinstance(fut, concurrent.futures.Future)
+            result = fut.result(timeout=5.0)
+            assert "action" in result.keys()
diff --git a/torchrl/modules/inference_server/__init__.py b/torchrl/modules/inference_server/__init__.py
new file mode 100644
index 00000000000..352246737b7
--- /dev/null
+++ b/torchrl/modules/inference_server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
+from torchrl.modules.inference_server._transport import InferenceTransport
+
+__all__ = [
+    "InferenceClient",
+    "InferenceServer",
+    "InferenceTransport",
+]
diff --git a/torchrl/modules/inference_server/_server.py b/torchrl/modules/inference_server/_server.py
new file mode 100644
index 00000000000..7751b2dac0a
--- /dev/null
+++ b/torchrl/modules/inference_server/_server.py
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import threading
+from collections.abc import Callable
+from concurrent.futures import Future
+
+import torch
+from tensordict import lazy_stack
+from tensordict.base import TensorDictBase
+from torch import nn
+
+from torchrl.modules.inference_server._transport import InferenceTransport
+
+
+class InferenceServer:
+    """Auto-batching inference server.
+
+    Actors submit individual TensorDicts via the *transport* and receive
+    results asynchronously. A background worker drains the transport queue,
+    batches inputs, runs the model, and fans results back to the callers.
+
+    Args:
+        model (nn.Module or Callable): a callable that maps a batched
+            TensorDictBase to a batched TensorDictBase (e.g. a
+            :class:`~tensordict.nn.TensorDictModule`).
+        transport (InferenceTransport): the communication backend.
+
+    Keyword Args:
+        max_batch_size (int, optional): upper bound on the number of requests
+            processed in a single forward pass. Default: ``64``.
+        timeout (float, optional): seconds to wait for new work before
+            dispatching a partial batch. Default: ``0.01``.
+        collate_fn (Callable, optional): function used to stack a list of
+            TensorDicts into a batch. Default: :func:`~tensordict.lazy_stack`.
+        device (torch.device or str, optional): device to move batches to
+            before calling the model. ``None`` means no device transfer.
+        weight_sync: an optional
+            :class:`~torchrl.weight_update.WeightSyncScheme` used to receive
+            updated model weights from a trainer. When set, the server polls
+            for new weights between inference batches.
+
+    Example:
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torchrl.modules.inference_server import (
+        ...     InferenceServer,
+        ...     ThreadingTransport,
+        ... )
+        >>> import torch.nn as nn
+        >>> policy = TensorDictModule(
+        ...     nn.Linear(4, 2), in_keys=["obs"], out_keys=["act"]
+        ... )
+        >>> transport = ThreadingTransport()
+        >>> server = InferenceServer(policy, transport, max_batch_size=8)
+        >>> server.start()
+        >>> client = transport.client()
+        >>> # client(td) can now be called from any thread
+        >>> server.shutdown()
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        transport: InferenceTransport,
+        *,
+        max_batch_size: int = 64,
+        timeout: float = 0.01,
+        collate_fn: Callable | None = None,
+        device: torch.device | str | None = None,
+        weight_sync=None,
+    ):
+        self.model = model
+        self.transport = transport
+        self.max_batch_size = max_batch_size
+        self.timeout = timeout
+        self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
+        self.device = torch.device(device) if device is not None else None
+        self.weight_sync = weight_sync
+
+        self._shutdown_event = threading.Event()
+        self._worker: threading.Thread | None = None
+
+    # -- lifecycle ------------------------------------------------------------
+
+    def start(self) -> InferenceServer:
+        """Start the background inference loop.
+
+        Returns:
+            self, for fluent chaining.
+        """
+        if self._worker is not None and self._worker.is_alive():
+            raise RuntimeError("Server is already running.")
+        self._shutdown_event.clear()
+        self._worker = threading.Thread(
+            target=self._run, daemon=True, name="InferenceServer-worker"
+        )
+        self._worker.start()
+        return self
+
+    def shutdown(self, timeout: float | None = 5.0) -> None:
+        """Signal the background worker to stop and wait for it to finish.
+
+        Args:
+            timeout (float or None): seconds to wait for the worker thread to
+                join. ``None`` waits indefinitely.
+        """
+        self._shutdown_event.set()
+        if self._worker is not None:
+            self._worker.join(timeout=timeout)
+            self._worker = None
+
+    @property
+    def is_alive(self) -> bool:
+        """Whether the background worker thread is running."""
+        return self._worker is not None and self._worker.is_alive()
+
+    # -- background loop ------------------------------------------------------
+
+    @torch.no_grad()
+    def _run(self) -> None:
+        try:
+            while not self._shutdown_event.is_set():
+                self.transport.wait_for_work(timeout=self.timeout)
+
+                items, callbacks = self.transport.drain(self.max_batch_size)
+                if not items:
+                    continue
+
+                try:
+                    # collate/device errors must resolve the futures, not kill the worker
+                    batch = self.collate_fn(items)
+                    if self.device is not None:
+                        batch = batch.to(self.device)
+                    results = self.model(batch).unbind(0)
+                    if len(results) != len(callbacks):
+                        raise RuntimeError(
+                            f"Model returned {len(results)} results for a "
+                            f"batch of {len(callbacks)} inputs."
+                        )
+                    for cb, res in zip(callbacks, results):
+                        self.transport.resolve(cb, res)
+                except Exception as exc:
+                    for cb in callbacks:
+                        self.transport.resolve_exception(cb, exc)
+        finally:
+            self._drain_pending_on_shutdown()
+
+    def _drain_pending_on_shutdown(self) -> None:
+        """Resolve all pending requests with an error during shutdown."""
+        shutdown_exc = RuntimeError("InferenceServer is shutting down.")
+        while True:
+            items, callbacks = self.transport.drain(self.max_batch_size)
+            if not items:
+                break
+            for cb in callbacks:
+                self.transport.resolve_exception(cb, shutdown_exc)
+
+    # -- context manager ------------------------------------------------------
+
+    def __enter__(self) -> InferenceServer:
+        return self.start()
+
+    def __exit__(self, *exc_info) -> None:
+        self.shutdown()
+
+    def __del__(self) -> None:
+        if self._worker is not None and self._worker.is_alive():
+            self.shutdown(timeout=1.0)
+
+
+class InferenceClient:
+    """Actor-side handle for an :class:`InferenceServer`.
+
+    Wraps a transport's :meth:`~InferenceTransport.submit` so that calling
+    ``client(td)`` looks like a regular synchronous policy call, while the
+    actual computation is batched on the server.
+
+    Args:
+        transport (InferenceTransport): the transport shared with the server.
+
+    Example:
+        >>> client = transport.client()
+        >>> td_out = client(td_in)  # blocking
+        >>> future = client.submit(td_in)  # non-blocking
+        >>> td_out = future.result()
+    """
+
+    def __init__(self, transport: InferenceTransport):
+        self._transport = transport
+
+    def __call__(self, td: TensorDictBase) -> TensorDictBase:
+        """Submit a request and block until the result is ready."""
+        return self._transport.submit(td).result()
+
+    def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
+        """Submit a request and return a Future immediately."""
+        return self._transport.submit(td)
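For reviewers, here is a minimal end-to-end sketch of how the pieces above compose. It is illustrative only and not part of the patch: ``InProcessTransport`` is a hypothetical name that mirrors the ``_MockTransport`` helper in the test file, and the only APIs assumed from this diff are ``InferenceServer``, ``InferenceTransport``, and ``transport.client()``.

# Illustrative sketch -- not part of the patch.
import concurrent.futures
import threading

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.modules.inference_server import InferenceServer, InferenceTransport


class InProcessTransport(InferenceTransport):
    """Queue-backed transport for actors in the same process (hypothetical)."""

    def __init__(self):
        self._pending = []  # list of (tensordict, future) pairs
        self._lock = threading.Lock()
        self._event = threading.Event()

    def submit(self, td):
        fut = concurrent.futures.Future()
        with self._lock:
            self._pending.append((td, fut))
            self._event.set()
        return fut

    def drain(self, max_items):
        with self._lock:
            batch, self._pending = self._pending[:max_items], self._pending[max_items:]
        return [td for td, _ in batch], [fut for _, fut in batch]

    def wait_for_work(self, timeout):
        self._event.wait(timeout=timeout)
        self._event.clear()

    def resolve(self, callback, result):
        callback.set_result(result)

    def resolve_exception(self, callback, exc):
        callback.set_exception(exc)


policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
transport = InProcessTransport()

with InferenceServer(policy, transport, max_batch_size=8):
    client = transport.client()

    def actor():
        # Looks like a plain synchronous policy call; batching is invisible.
        out = client(TensorDict({"observation": torch.randn(4)}))
        assert out["action"].shape == (2,)

    threads = [threading.Thread(target=actor) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Each actor thread sees an ordinary blocking call; requests that arrive within the server's ``timeout`` window are served together in a single forward pass.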
diff --git a/torchrl/modules/inference_server/_transport.py b/torchrl/modules/inference_server/_transport.py
new file mode 100644
index 00000000000..639a029d9eb
--- /dev/null
+++ b/torchrl/modules/inference_server/_transport.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+from concurrent.futures import Future
+
+from tensordict.base import TensorDictBase
+
+
+class InferenceTransport(abc.ABC):
+    """Abstract base class for inference server transport backends.
+
+    A transport handles the communication between actor-side clients and the
+    server-side inference loop. Concrete implementations provide the mechanism
+    for submitting requests, draining batches, and routing results back.
+
+    Subclasses must implement :meth:`submit`, :meth:`drain`, :meth:`wait_for_work`,
+    :meth:`resolve`, and :meth:`resolve_exception`.
+    """
+
+    @abc.abstractmethod
+    def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
+        """Submit a single inference request.
+
+        Called on the actor side. Returns a :class:`~concurrent.futures.Future`
+        (or future-like object) that will be resolved with the inference result.
+
+        Args:
+            td (TensorDictBase): a single (unbatched) input tensordict.
+
+        Returns:
+            A Future that resolves to the output TensorDictBase.
+        """
+        ...
+
+    @abc.abstractmethod
+    def drain(self, max_items: int) -> tuple[list[TensorDictBase], list]:
+        """Drain up to *max_items* pending requests from the queue.
+
+        Called on the server side. Returns a pair ``(inputs, callbacks)`` where
+        ``inputs`` is a list of TensorDicts and ``callbacks`` is a list of
+        opaque objects that :meth:`resolve` knows how to fulfil.
+
+        Args:
+            max_items (int): maximum number of items to dequeue.
+
+        Returns:
+            Tuple of ``(inputs, callbacks)``.
+        """
+        ...
+
+    @abc.abstractmethod
+    def wait_for_work(self, timeout: float) -> None:
+        """Block until new work is available or *timeout* seconds elapse.
+
+        Called on the server side before :meth:`drain`.
+
+        Args:
+            timeout (float): maximum seconds to wait.
+        """
+        ...
+
+    @abc.abstractmethod
+    def resolve(self, callback, result: TensorDictBase) -> None:
+        """Send a result back to the actor that submitted the request.
+
+        Args:
+            callback: an opaque handle returned by :meth:`drain`.
+            result (TensorDictBase): the inference output for this request.
+        """
+        ...
+
+    @abc.abstractmethod
+    def resolve_exception(self, callback, exc: BaseException) -> None:
+        """Propagate an exception back to the actor that submitted the request.
+
+        Args:
+            callback: an opaque handle returned by :meth:`drain`.
+            exc (BaseException): the exception to propagate.
+        """
+        ...
+
+    def client(self) -> InferenceClient:  # noqa: F821
+        """Return an actor-side :class:`InferenceClient` bound to this transport."""
+        from torchrl.modules.inference_server._server import InferenceClient
+
+        return InferenceClient(self)
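A second illustrative sketch, this time of the non-blocking path through ``InferenceClient.submit`` combined with a custom ``collate_fn``. It reuses the hypothetical ``InProcessTransport`` from the previous sketch; ``logging_collate`` is likewise made up for illustration and is not part of the patch.

# Illustrative sketch -- not part of the patch. Assumes InProcessTransport
# from the previous example is in scope.
import concurrent.futures

import torch
import torch.nn as nn
from tensordict import lazy_stack, TensorDict
from tensordict.nn import TensorDictModule

from torchrl.modules.inference_server import InferenceServer

policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
transport = InProcessTransport()


def logging_collate(items):
    # A collate_fn only has to turn a list of TensorDicts into one batch;
    # this one reports the realised batch size before stacking lazily.
    print(f"dispatching a batch of {len(items)} requests")
    return lazy_stack(items)


with InferenceServer(
    policy, transport, max_batch_size=16, timeout=0.05, collate_fn=logging_collate
):
    client = transport.client()
    # submit() returns immediately; results are collected through Futures.
    futures = [
        client.submit(TensorDict({"observation": torch.randn(4)})) for _ in range(8)
    ]
    done, not_done = concurrent.futures.wait(futures, timeout=5.0)
    assert not not_done
    assert all(f.result()["action"].shape == (2,) for f in done)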