44# LICENSE file in the root directory of this source tree.
55from __future__ import annotations
66
7- import queue
87import threading
9- import time
10- from typing import Any
118
12- from tensordict .base import TensorDictBase
9+ from torchrl .modules .inference_server ._queue_transport import (
10+ _QueueInferenceClient ,
11+ QueueBasedTransport ,
12+ )
1313
14- from torchrl .modules .inference_server ._transport import InferenceTransport
1514
16- _SENTINEL = object ()
15+ class _MonarchRequestQueue :
16+ """Wrapper around ``MonarchQueue`` that signals a :class:`threading.Event` on put.
1717
18+ Also adapts the Monarch queue API (``get(block=False)``) to the standard
19+ ``get_nowait()`` expected by :class:`QueueBasedTransport`.
20+ """
1821
19- class _MonarchFuture :
20- """Future-like object for Monarch transport results.
22+ def __init__ (self , monarch_queue , has_work : threading .Event ):
23+ self ._queue = monarch_queue
24+ self ._has_work = has_work
2125
22- Args:
23- client: the :class:`_MonarchInferenceClient` that created this future.
24- req_id: the unique request identifier within that client.
25- """
26+ def put (self , item ):
27+ self ._queue .put (item )
28+ self ._has_work .set ()
2629
27- def __init__ (self , client : _MonarchInferenceClient , req_id : int ):
28- self ._client = client
29- self ._req_id = req_id
30- self ._result : Any = _SENTINEL
30+ def get (self , timeout = None ):
31+ return self ._queue .get (timeout = timeout )
3132
32- def done (self ) -> bool :
33- """Return ``True`` if the result is available without blocking."""
34- if self ._result is not _SENTINEL :
35- return True
36- try :
37- self ._result = self ._client ._get_result (self ._req_id , timeout = 0 )
38- except queue .Empty :
39- return False
40- return True
41-
42- def result (self , timeout : float | None = None ) -> TensorDictBase :
43- """Block until the result is available."""
44- if self ._result is _SENTINEL :
45- self ._result = self ._client ._get_result (self ._req_id , timeout = timeout )
46- if isinstance (self ._result , BaseException ):
47- raise self ._result
48- return self ._result
49-
50-
51- class _MonarchInferenceClient :
52- """Actor-side client for :class:`MonarchTransport`.
53-
54- Each client owns a dedicated response queue and routes results by
55- request-id.
56-
57- Args:
58- request_queue: the shared Monarch queue for requests.
59- response_queue: this client's dedicated response queue.
60- actor_id: the unique identifier assigned by the transport.
61- """
33+ def get_nowait (self ):
34+ return self ._queue .get (block = False )
35+
36+
37+ class _MonarchResponseQueue :
38+ """Thin wrapper adapting the MonarchQueue get API."""
39+
40+ def __init__ (self , monarch_queue ):
41+ self ._queue = monarch_queue
6242
63- def __init__ (self , request_queue , response_queue , actor_id : int ):
64- self ._request_queue = request_queue
65- self ._response_queue = response_queue
66- self ._actor_id = actor_id
67- self ._next_req_id = 0
68- self ._buffered : dict [int , Any ] = {}
69-
70- def __call__ (self , td : TensorDictBase ) -> TensorDictBase :
71- """Submit a request and block until the result is ready."""
72- return self .submit (td ).result ()
73-
74- def submit (self , td : TensorDictBase ) -> _MonarchFuture :
75- """Submit a request and return a :class:`_MonarchFuture`."""
76- req_id = self ._next_req_id
77- self ._next_req_id += 1
78- self ._request_queue .put ((self ._actor_id , req_id , td ))
79- return _MonarchFuture (self , req_id )
80-
81- # -- internal -------------------------------------------------------------
82-
83- def _get_result (self , req_id : int , timeout : float | None = None ) -> Any :
84- """Return the result for *req_id*, buffering any earlier arrivals."""
85- if req_id in self ._buffered :
86- return self ._buffered .pop (req_id )
87- deadline = None if timeout is None else time .monotonic () + timeout
88- while True :
89- remaining = None
90- if deadline is not None :
91- remaining = deadline - time .monotonic ()
92- if remaining <= 0 :
93- raise queue .Empty (f"Timeout waiting for result of request { req_id } " )
94- try :
95- rid , result = self ._response_queue .get (timeout = remaining )
96- except Exception :
97- raise queue .Empty (f"Timeout waiting for result of request { req_id } " )
98- if rid == req_id :
99- return result
100- self ._buffered [rid ] = result
101-
102-
103- class MonarchTransport (InferenceTransport ):
43+ def put (self , item ):
44+ self ._queue .put (item )
45+
46+ def get (self , timeout = None ):
47+ return self ._queue .get (timeout = timeout )
48+
49+
50+ class MonarchTransport (QueueBasedTransport ):
10451 """Transport using Monarch for distributed inference on GPU clusters.
10552
10653 Uses Monarch's actor model and RDMA-capable channels for efficient
@@ -118,6 +65,7 @@ class MonarchTransport(InferenceTransport):
11865 """
11966
12067 def __init__ (self , * , max_queue_size : int = 1000 ):
68+ super ().__init__ ()
12169 try :
12270 import monarch # noqa: F401
12371 from monarch .tools .queue import MonarchQueue
@@ -126,66 +74,21 @@ def __init__(self, *, max_queue_size: int = 1000):
12674 "Monarch is required for MonarchTransport. "
12775 "Install it following the Monarch documentation."
12876 )
129- self ._request_queue = MonarchQueue (maxsize = max_queue_size )
130- self ._response_queues : dict [int , Any ] = {}
131- self ._lock = threading .Lock ()
132- self ._next_actor_id = 0
77+ self ._has_work = threading .Event ()
78+ self ._request_queue = _MonarchRequestQueue (
79+ MonarchQueue (maxsize = max_queue_size ), self ._has_work
80+ )
81+ self ._response_queues : dict [int , _MonarchResponseQueue ] = {}
13382 self ._MonarchQueue = MonarchQueue
13483
135- # -- actor API ------------------------------------------------------------
84+ def _make_response_queue (self ) -> _MonarchResponseQueue :
85+ return _MonarchResponseQueue (self ._MonarchQueue (maxsize = 1000 ))
13686
137- def client (self ) -> _MonarchInferenceClient :
87+ def client (self ) -> _QueueInferenceClient :
13888 """Create an actor-side client with a dedicated response queue.
13989
14090 Returns:
141- A :class:`_MonarchInferenceClient ` that can be passed to a Monarch
91+ A :class:`_QueueInferenceClient ` that can be passed to a Monarch
14292 actor.
14393 """
144- with self ._lock :
145- actor_id = self ._next_actor_id
146- self ._next_actor_id += 1
147- response_queue = self ._MonarchQueue (maxsize = 1000 )
148- self ._response_queues [actor_id ] = response_queue
149- return _MonarchInferenceClient (self ._request_queue , response_queue , actor_id )
150-
151- def submit (self , td : TensorDictBase ):
152- """Not supported -- use :meth:`client` to obtain an actor handle."""
153- raise RuntimeError (
154- "MonarchTransport.submit() is not supported. "
155- "Call transport.client() to create a _MonarchInferenceClient."
156- )
157-
158- # -- server API -----------------------------------------------------------
159-
160- def drain (
161- self , max_items : int
162- ) -> tuple [list [TensorDictBase ], list [tuple [int , int ]]]:
163- """Dequeue up to *max_items* pending requests (non-blocking)."""
164- items : list [TensorDictBase ] = []
165- callbacks : list [tuple [int , int ]] = []
166- for _ in range (max_items ):
167- try :
168- actor_id , req_id , td = self ._request_queue .get (block = False )
169- items .append (td )
170- callbacks .append ((actor_id , req_id ))
171- except Exception :
172- break
173- return items , callbacks
174-
175- def wait_for_work (self , timeout : float ) -> None :
176- """Block until at least one request is available or *timeout* elapses."""
177- try :
178- item = self ._request_queue .get (timeout = timeout )
179- self ._request_queue .put (item )
180- except Exception :
181- pass
182-
183- def resolve (self , callback : tuple [int , int ], result : TensorDictBase ) -> None :
184- """Route the result to the correct actor's response queue."""
185- actor_id , req_id = callback
186- self ._response_queues [actor_id ].put ((req_id , result ))
187-
188- def resolve_exception (self , callback : tuple [int , int ], exc : BaseException ) -> None :
189- """Route an exception to the correct actor's response queue."""
190- actor_id , req_id = callback
191- self ._response_queues [actor_id ].put ((req_id , exc ))
94+ return super ().client ()
0 commit comments