|
25 | 25 | from typing import Dict, Optional, Tuple, List, Any |
26 | 26 | import json |
27 | 27 | import logging |
| 28 | +import os |
28 | 29 | import pprint |
29 | 30 | import pickle |
30 | 31 | from pathlib import Path |
@@ -134,7 +135,10 @@ def __init__(self, end_point: str) -> ModelServer: |
134 | 135 | logging.debug( |
135 | 136 | f"Python model trying to connect to manager at {end_point}") |
136 | 137 | self.socket.connect(f"ipc://{end_point}") |
137 | | - logging.debug(f"Python model connected at {end_point}") |
| 138 | + logging.info(f"Python model connected at {end_point}") |
| 139 | + |
| 140 | + # If the ModelServer is closing |
| 141 | + self._closing = False |
138 | 142 |
|
139 | 143 | # Register the exit callback |
140 | 144 | atexit.register(self.cleanup_zmq) |
@@ -331,6 +335,19 @@ def _infer(self, data: Dict) -> Tuple[List, bool, str]: |
331 | 335 |
|
332 | 336 | return y_pred.tolist(), True, "" |
333 | 337 |
|
| 338 | + def _recv(self) -> str: |
| 339 | + """ |
| 340 | + Receive from the ZMQ socket. This is a blocking call. |
| 341 | +
|
| 342 | + :return: Message paylod |
| 343 | + """ |
| 344 | + identity = self.socket.recv() |
| 345 | + _delim = self.socket.recv() |
| 346 | + payload = self.socket.recv() |
| 347 | + logging.debug(f"Python recv: {str(identity)}, {str(payload)}") |
| 348 | + |
| 349 | + return payload.decode("ascii") |
| 350 | + |
334 | 351 | def _execute_cmd(self, cmd: Command, data: Dict) -> Tuple[Dict, bool]: |
335 | 352 | """ |
336 | 353 | Execute a command from the ModelServerManager |
@@ -382,12 +399,20 @@ def run_loop(self): |
382 | 399 | """ |
383 | 400 |
|
384 | 401 | while(1): |
385 | | - identity = self.socket.recv() |
386 | | - _delim = self.socket.recv() |
387 | | - payload = self.socket.recv() |
388 | | - logging.debug(f"Python recv: {str(identity)}, {str(payload)}") |
| 402 | + try: |
| 403 | + payload = self._recv() |
| 404 | + except UnicodeError as e: |
| 405 | + logging.warning(f"Failed to decode : {e.reason}") |
| 406 | + continue |
| 407 | + except KeyboardInterrupt: |
| 408 | + if self._closing: |
| 409 | + logging.warning("Forced shutting down now.") |
| 410 | + os._exit(-1) |
| 411 | + else: |
| 412 | + logging.info("Received KeyboardInterrupt. Ctrl+C again to force shutting down.") |
| 413 | + self._closing = True |
| 414 | + continue |
389 | 415 |
|
390 | | - payload = payload.decode("ascii") |
391 | 416 | send_id, recv_id, msg = self._parse_msg(payload) |
392 | 417 | if msg is None: |
393 | 418 | continue |
|
0 commit comments