@@ -74,7 +74,7 @@ class _QueueInferenceClient:
7474 request-id.
7575
7676 Args:
77- request_queue: the shared request queue.
77+ request_queue: the shared request queue (any object with ``.put()``) .
7878 response_queue: this client's dedicated response queue.
7979 actor_id: the unique identifier assigned by the transport.
8080 """
@@ -122,18 +122,23 @@ def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
122122class QueueBasedTransport (InferenceTransport ):
123123 """Base class for transports that use a request queue and per-actor response queues.
124124
125- Subclasses must set the following attributes before calling ``super().__init__()``:
125+ Subclasses must set the following attributes in ``__init__`` (before or
126+ after calling ``super().__init__()``):
126127
127- * ``_request_queue`` -- the shared request queue (any object with ``.put()``,
128- ``.get(timeout=... )``, and ``.get_nowait( )`` / ``.get(block=False)``).
128+ * ``_request_queue`` -- the shared request queue (any object with
129+ ``.put( )``, ``.get(timeout=... )``, and ``.get(block=False)``).
129130 * ``_response_queues`` -- a ``dict[int, <queue>]`` mapping actor ids to
130131 per-actor response queues.
131- * ``_has_work`` -- a :class:`threading.Event` that is set whenever a new
132- request is enqueued (used for non-blocking ``wait_for_work``).
133132
134133 Subclasses must implement:
135134
136135 * :meth:`_make_response_queue` -- factory for creating a new response queue.
136+
137+ .. note::
138+ ``wait_for_work`` uses a blocking ``get`` followed by ``put`` to peek
139+ at the request queue. This is safe because a single server thread
140+ calls both ``wait_for_work`` and ``drain`` sequentially -- there is no
141+ concurrent consumer that could miss the re-enqueued item.
137142 """
138143
139144 def __init__ (self ):
@@ -178,7 +183,7 @@ def drain(
178183 callbacks : list [tuple [int , int ]] = []
179184 for _ in range (max_items ):
180185 try :
181- actor_id , req_id , td = self ._request_queue .get_nowait ( )
186+ actor_id , req_id , td = self ._request_queue .get ( block = False )
182187 except Exception :
183188 break
184189 items .append (td )
@@ -187,8 +192,12 @@ def drain(
187192
188193 def wait_for_work (self , timeout : float ) -> None :
189194 """Block until at least one request is available or *timeout* elapses."""
190- self ._has_work .wait (timeout = timeout )
191- self ._has_work .clear ()
195+ try :
196+ item = self ._request_queue .get (timeout = timeout )
197+ # Put it back so drain() can consume it.
198+ self ._request_queue .put (item )
199+ except Exception :
200+ pass
192201
193202 def resolve (self , callback : tuple [int , int ], result : TensorDictBase ) -> None :
194203 """Route the result to the correct actor's response queue."""
0 commit comments