55from __future__ import annotations
66
77import multiprocessing as mp
8- import queue
98import threading
10- import time
11- from typing import Any
129
13- from tensordict .base import TensorDictBase
10+ from torchrl .modules .inference_server ._queue_transport import (
11+ _QueueInferenceClient ,
12+ QueueBasedTransport ,
13+ )
1414
15- from torchrl .modules .inference_server ._transport import InferenceTransport
1615
17- _SENTINEL = object ()
16+ class _MPRequestQueue :
17+ """Wrapper around ``mp.Queue`` that signals a :class:`threading.Event` on put.
1818
19-
20- class _MPFuture :
21- """Future-like object backed by a per-actor response queue.
22-
23- The future retrieves its result by request-id so that out-of-order
24- ``result()`` calls work correctly.
25-
26- Args:
27- client: the :class:`_MPInferenceClient` that created this future.
28- req_id: the unique request identifier within that client.
19+ This avoids the get-then-put anti-pattern in ``wait_for_work``: instead of
20+ consuming an item just to peek, callers wait on the event.
2921 """
3022
31- def __init__ (self , client : _MPInferenceClient , req_id : int ):
32- self ._client = client
33- self ._req_id = req_id
34- self ._result : Any = _SENTINEL
35-
36- def done (self ) -> bool :
37- """Return ``True`` if the result is available without blocking."""
38- if self ._result is not _SENTINEL :
39- return True
40- try :
41- self ._result = self ._client ._get_result (self ._req_id , timeout = 0 )
42- except queue .Empty :
43- return False
44- return True
45-
46- def result (self , timeout : float | None = None ) -> TensorDictBase :
47- """Block until the result is available.
48-
49- Args:
50- timeout: seconds to wait. ``None`` waits indefinitely.
51-
52- Raises:
53- queue.Empty: if *timeout* expires before a result arrives.
54- Exception: if the server set an exception instead of a result.
55- """
56- if self ._result is _SENTINEL :
57- self ._result = self ._client ._get_result (self ._req_id , timeout = timeout )
58- if isinstance (self ._result , BaseException ):
59- raise self ._result
60- return self ._result
23+ def __init__ (self , ctx : mp .context .BaseContext , has_work : threading .Event ):
24+ self ._queue : mp .Queue = ctx .Queue ()
25+ self ._has_work = has_work
6126
27+ def put (self , item ):
28+ self ._queue .put (item )
29+ self ._has_work .set ()
6230
63- class _MPInferenceClient :
64- """Actor-side client for :class:`MPTransport`.
def get (self , timeout = None ):
return self ._queue .get (timeout = timeout )
6533
66- Each client owns a dedicated response queue and routes results by
67- request-id. Instances are created by :meth:`MPTransport.client` and
68- must be created **before** spawning child processes so that the
69- underlying queues are inherited.
34+ def get_nowait (self ):
35+ return self ._queue .get_nowait ()
7036
71- Args:
72- request_queue: the shared request queue.
73- response_queue: this client's dedicated response queue.
74- actor_id: the unique identifier assigned by the transport.
75- """
7637
77- def __init__ (
78- self ,
79- request_queue : mp .Queue ,
80- response_queue : mp .Queue ,
81- actor_id : int ,
82- ):
83- self ._request_queue = request_queue
84- self ._response_queue = response_queue
85- self ._actor_id = actor_id
86- self ._next_req_id = 0
87- self ._buffered : dict [int , Any ] = {}
88-
89- def __call__ (self , td : TensorDictBase ) -> TensorDictBase :
90- """Submit a request and block until the result is ready."""
91- return self .submit (td ).result ()
92-
93- def submit (self , td : TensorDictBase ) -> _MPFuture :
94- """Submit a request and return an :class:`_MPFuture`."""
95- req_id = self ._next_req_id
96- self ._next_req_id += 1
97- self ._request_queue .put ((self ._actor_id , req_id , td ))
98- return _MPFuture (self , req_id )
99-
100- # -- internal -------------------------------------------------------------
101-
102- def _get_result (self , req_id : int , timeout : float | None = None ) -> Any :
103- """Return the result for *req_id*, buffering any earlier arrivals."""
104- if req_id in self ._buffered :
105- return self ._buffered .pop (req_id )
106- deadline = None if timeout is None else time .monotonic () + timeout
107- while True :
108- remaining = None
109- if deadline is not None :
110- remaining = deadline - time .monotonic ()
111- if remaining <= 0 :
112- raise queue .Empty (f"Timeout waiting for result of request { req_id } " )
113- rid , result = self ._response_queue .get (timeout = remaining )
114- if rid == req_id :
115- return result
116- self ._buffered [rid ] = result
117-
118-
119- class MPTransport (InferenceTransport ):
38+ class MPTransport (QueueBasedTransport ):
12039 """Cross-process transport using :mod:`multiprocessing` queues.
12140
12241 Response routing uses per-actor queues (one per :meth:`client` call) so
@@ -137,69 +56,22 @@ class MPTransport(InferenceTransport):
13756 """
13857
13958 def __init__ (self , ctx : mp .context .BaseContext | None = None ):
59+ super ().__init__ ()
14060 self ._ctx = ctx if ctx is not None else mp .get_context ("spawn" )
141- self ._request_queue : mp .Queue = self ._ctx .Queue ()
61+ self ._has_work = threading .Event ()
62+ self ._request_queue = _MPRequestQueue (self ._ctx , self ._has_work )
14263 self ._response_queues : dict [int , mp .Queue ] = {}
143- self ._lock = threading .Lock ()
144- self ._next_actor_id = 0
14564
146- # -- actor API (called before fork) ---------------------------------------
65+ def _make_response_queue (self ) -> mp .Queue :
66+ return self ._ctx .Queue ()
14767
148- def client (self ) -> _MPInferenceClient :
68+ def client (self ) -> _QueueInferenceClient :
14969 """Create an actor-side client with a dedicated response queue.
15070
15171 Must be called in the parent process **before** spawning children.
15272
15373 Returns:
154- An :class:`_MPInferenceClient ` that can be passed to a child
74+ A :class:`_QueueInferenceClient ` that can be passed to a child
15575 process as an argument to :class:`multiprocessing.Process`.
15676 """
157- with self ._lock :
158- actor_id = self ._next_actor_id
159- self ._next_actor_id += 1
160- response_queue : mp .Queue = self ._ctx .Queue ()
161- self ._response_queues [actor_id ] = response_queue
162- return _MPInferenceClient (self ._request_queue , response_queue , actor_id )
163-
164- def submit (self , td : TensorDictBase ):
165- """Not supported -- use :meth:`client` to obtain an actor handle."""
166- raise RuntimeError (
167- "MPTransport.submit() is not supported. "
168- "Call transport.client() to create an _MPInferenceClient."
169- )
170-
171- # -- server API -----------------------------------------------------------
172-
173- def drain (
174- self , max_items : int
175- ) -> tuple [list [TensorDictBase ], list [tuple [int , int ]]]:
176- """Dequeue up to *max_items* pending ``(actor_id, req_id, td)`` tuples."""
177- items : list [TensorDictBase ] = []
178- callbacks : list [tuple [int , int ]] = []
179- for _ in range (max_items ):
180- try :
181- actor_id , req_id , td = self ._request_queue .get_nowait ()
182- items .append (td )
183- callbacks .append ((actor_id , req_id ))
184- except queue .Empty :
185- break
186- return items , callbacks
187-
188- def wait_for_work (self , timeout : float ) -> None :
189- """Block until at least one request is available or *timeout* elapses."""
190- try :
191- item = self ._request_queue .get (timeout = timeout )
192- # Put it back so drain() can consume it.
193- self ._request_queue .put (item )
194- except queue .Empty :
195- pass
196-
197- def resolve (self , callback : tuple [int , int ], result : TensorDictBase ) -> None :
198- """Route the result to the correct actor's response queue."""
199- actor_id , req_id = callback
200- self ._response_queues [actor_id ].put ((req_id , result ))
201-
202- def resolve_exception (self , callback : tuple [int , int ], exc : BaseException ) -> None :
203- """Route an exception to the correct actor's response queue."""
204- actor_id , req_id = callback
205- self ._response_queues [actor_id ].put ((req_id , exc ))
77+ return super ().client ()