|
14 | 14 | # limitations under the License.
|
15 | 15 | """
|
16 | 16 |
|
17 |
| -import os |
18 |
| -import threading |
19 |
| -import time |
20 |
| -import traceback |
| 17 | +from abc import ABC, abstractmethod |
21 | 18 |
|
22 |
| -import msgpack |
23 | 19 | import zmq
|
24 | 20 |
|
25 |
| -from fastdeploy import envs |
26 |
| -from fastdeploy.utils import zmq_client_logger |
27 | 21 |
|
28 |
| - |
29 |
| -class ZmqClient: |
| 22 | +class ZmqClientBase(ABC): |
30 | 23 | """
|
31 |
| - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. |
| 24 | + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. |
32 | 25 | """
|
33 | 26 |
|
34 |
| - def __init__(self, name, mode): |
35 |
| - self.context = zmq.Context(4) |
36 |
| - self.socket = self.context.socket(mode) |
37 |
| - self.file_name = f"/dev/shm/{name}.socket" |
38 |
| - self.router_path = f"/dev/shm/router_{name}.ipc" |
| 27 | + def __init__(self): |
| 28 | + pass |
39 | 29 |
|
40 |
| - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) |
41 |
| - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND |
| 30 | + @abstractmethod |
| 31 | + def _create_socket(self): |
| 32 | + """Abstract method to create and return a ZeroMQ socket.""" |
| 33 | + pass |
42 | 34 |
|
43 |
| - self.mutex = threading.Lock() |
44 |
| - self.req_dict = dict() |
45 |
| - self.router = None |
46 |
| - self.poller = None |
47 |
| - self.running = True |
| 35 | + def _ensure_socket(self): |
| 36 | + """Ensure the socket is created before use.""" |
| 37 | + if self.socket is None: |
| 38 | + self.socket = self._create_socket() |
48 | 39 |
|
| 40 | + @abstractmethod |
49 | 41 | def connect(self):
|
50 | 42 | """
|
51 | 43 | Connect to the server using the file name specified in the constructor.
|
52 | 44 | """
|
53 |
| - self.socket.connect(f"ipc://{self.file_name}") |
54 |
| - |
55 |
| - def start_server(self): |
56 |
| - """ |
57 |
| - Start the server using the file name specified in the constructor. |
58 |
| - """ |
59 |
| - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) |
60 |
| - self.socket.setsockopt(zmq.SNDTIMEO, -1) |
61 |
| - self.socket.bind(f"ipc://{self.file_name}") |
62 |
| - self.poller = zmq.Poller() |
63 |
| - self.poller.register(self.socket, zmq.POLLIN) |
64 |
| - |
65 |
| - def create_router(self): |
66 |
| - """ |
67 |
| - Create a ROUTER socket and bind it to the specified router path. |
68 |
| - """ |
69 |
| - self.router = self.context.socket(zmq.ROUTER) |
70 |
| - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) |
71 |
| - self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) |
72 |
| - self.router.setsockopt(zmq.SNDTIMEO, -1) |
73 |
| - self.router.bind(f"ipc://{self.router_path}") |
74 |
| - zmq_client_logger.info(f"router path: {self.router_path}") |
| 45 | + pass |
75 | 46 |
|
76 | 47 | def send_json(self, data):
|
77 | 48 | """
|
78 | 49 | Send a JSON-serializable object over the socket.
|
79 | 50 | """
|
| 51 | + self._ensure_socket() |
80 | 52 | self.socket.send_json(data)
|
81 | 53 |
|
82 | 54 | def recv_json(self):
|
83 | 55 | """
|
84 | 56 | Receive a JSON-serializable object from the socket.
|
85 | 57 | """
|
| 58 | + self._ensure_socket() |
86 | 59 | return self.socket.recv_json()
|
87 | 60 |
|
88 | 61 | def send_pyobj(self, data):
|
89 | 62 | """
|
90 | 63 | Send a Pickle-serializable object over the socket.
|
91 | 64 | """
|
| 65 | + self._ensure_socket() |
92 | 66 | self.socket.send_pyobj(data)
|
93 | 67 |
|
94 | 68 | def recv_pyobj(self):
|
95 | 69 | """
|
96 | 70 | Receive a Pickle-serializable object from the socket.
|
97 | 71 | """
|
| 72 | + self._ensure_socket() |
98 | 73 | return self.socket.recv_pyobj()
|
99 | 74 |
|
100 |
| - def pack_aggregated_data(self, data): |
101 |
| - """ |
102 |
| - Aggregate multiple responses into one and send them to the client. |
103 |
| - """ |
104 |
| - result = data[0] |
105 |
| - if len(data) > 1: |
106 |
| - for response in data[1:]: |
107 |
| - result.add(response) |
108 |
| - result = msgpack.packb([result.to_dict()]) |
109 |
| - return result |
110 |
| - |
111 |
| - def send_multipart(self, req_id, data): |
112 |
| - """ |
113 |
| - Send a multipart message to the router socket. |
114 |
| - """ |
115 |
| - if self.router is None: |
116 |
| - raise RuntimeError("Router socket not created. Call create_router() first.") |
117 |
| - |
118 |
| - while self.running: |
119 |
| - with self.mutex: |
120 |
| - if req_id not in self.req_dict: |
121 |
| - try: |
122 |
| - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) |
123 |
| - req_id_str = request_id.decode("utf-8") |
124 |
| - self.req_dict[req_id_str] = client |
125 |
| - except zmq.Again: |
126 |
| - time.sleep(0.001) |
127 |
| - continue |
128 |
| - else: |
129 |
| - break |
130 |
| - if self.req_dict[req_id] == -1: |
131 |
| - if data[-1].finished: |
132 |
| - with self.mutex: |
133 |
| - self.req_dict.pop(req_id, None) |
134 |
| - return |
135 |
| - try: |
136 |
| - start_send = time.time() |
137 |
| - if self.aggregate_send: |
138 |
| - result = self.pack_aggregated_data(data) |
139 |
| - else: |
140 |
| - result = msgpack.packb([response.to_dict() for response in data]) |
141 |
| - self.router.send_multipart([self.req_dict[req_id], b"", result]) |
142 |
| - zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") |
143 |
| - except zmq.ZMQError as e: |
144 |
| - zmq_client_logger.error(f"[{req_id}] zmq error: {e}") |
145 |
| - self.req_dict[req_id] = -1 |
146 |
| - except Exception as e: |
147 |
| - zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") |
148 |
| - |
149 |
| - if data[-1].finished: |
150 |
| - with self.mutex: |
151 |
| - self.req_dict.pop(req_id, None) |
152 |
| - zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}") |
153 |
| - |
154 |
| - def receive_json_once(self, block=False): |
155 |
| - """ |
156 |
| - Receive a single message from the socket. |
157 |
| - """ |
158 |
| - if self.socket is None or self.socket.closed: |
159 |
| - return "zmp socket has closed", None |
160 |
| - try: |
161 |
| - flags = zmq.NOBLOCK if not block else 0 |
162 |
| - return None, self.socket.recv_json(flags=flags) |
163 |
| - except zmq.Again: |
164 |
| - return None, None |
165 |
| - except Exception as e: |
166 |
| - self.close() |
167 |
| - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") |
168 |
| - return str(e), None |
169 |
| - |
170 |
| - def receive_pyobj_once(self, block=False): |
171 |
| - """ |
172 |
| - Receive a single message from the socket. |
173 |
| - """ |
174 |
| - if self.socket is None or self.socket.closed: |
175 |
| - return "zmp socket has closed", None |
176 |
| - try: |
177 |
| - flags = zmq.NOBLOCK if not block else 0 |
178 |
| - return None, self.socket.recv_pyobj(flags=flags) |
179 |
| - except zmq.Again: |
180 |
| - return None, None |
181 |
| - except Exception as e: |
182 |
| - self.close() |
183 |
| - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") |
184 |
| - return str(e), None |
185 |
| - |
186 |
| - def _clear_ipc(self, name): |
187 |
| - """ |
188 |
| - Remove the IPC file with the given name. |
189 |
| - """ |
190 |
| - if os.path.exists(name): |
191 |
| - try: |
192 |
| - os.remove(name) |
193 |
| - except OSError as e: |
194 |
| - zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}") |
195 |
| - |
196 |
| - def close(self): |
197 |
| - """ |
198 |
| - Close the socket and context, and remove the IPC files. |
199 |
| - """ |
200 |
| - if not self.running: |
201 |
| - return |
202 |
| - |
203 |
| - self.running = False |
204 |
| - zmq_client_logger.info("Closing ZMQ connection...") |
205 |
| - try: |
206 |
| - if hasattr(self, "socket") and not self.socket.closed: |
207 |
| - self.socket.close() |
208 | 75 |
|
209 |
| - if self.router is not None and not self.router.closed: |
210 |
| - self.router.close() |
211 |
| - |
212 |
| - if not self.context.closed: |
213 |
| - self.context.term() |
| 76 | +class ZmqIpcClient(ZmqClientBase): |
| 77 | + def __init__(self, name, mode): |
| 78 | + self.name = name |
| 79 | + self.mode = mode |
| 80 | + self.file_name = f"/dev/shm/{name}.socket" |
| 81 | + self.context = zmq.Context() |
| 82 | + self.socket = self.context.socket(self.mode) |
214 | 83 |
|
215 |
| - self._clear_ipc(self.file_name) |
216 |
| - self._clear_ipc(self.router_path) |
217 |
| - except Exception as e: |
218 |
| - zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") |
219 |
| - return |
| 84 | + def _create_socket(self): |
| 85 | + """create and return a ZeroMQ socket.""" |
| 86 | + self.context = zmq.Context() |
| 87 | + return self.context.socket(self.mode) |
220 | 88 |
|
221 |
| - def __exit__(self, exc_type, exc_val, exc_tb): |
222 |
| - self.close() |
| 89 | + def connect(self): |
| 90 | + self._ensure_socket() |
| 91 | + self.socket.connect(f"ipc://{self.file_name}") |
0 commit comments