|
13 | 13 | # |
14 | 14 | # You should have received a copy of the GNU General Public License |
15 | 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
16 | | -"""Package for network interfaces to Cylc scheduler objects.""" |
17 | 16 |
|
18 | | -import asyncio |
19 | | -import getpass |
20 | | -import json |
21 | | -from typing import Optional, Tuple |
| 17 | +"""Cylc networking code. |
22 | 18 |
|
23 | | -import zmq |
24 | | -import zmq.asyncio |
25 | | -import zmq.auth |
| 19 | +Contains: |
| 20 | +* Server code (hosted by the scheduler process). |
| 21 | +* Client implementations (used to communicate with the scheduler). |
| 22 | +* Workflow scanning logic. |
| 23 | +* Schema and interface definitions. |
| 24 | +""" |
26 | 25 |
|
27 | | -from cylc.flow import LOG |
28 | | -from cylc.flow.exceptions import ( |
29 | | - ClientError, |
30 | | - CylcError, |
31 | | - CylcVersionError, |
32 | | - ServiceFileError, |
33 | | - WorkflowStopped |
34 | | -) |
35 | | -from cylc.flow.hostuserutil import get_fqdn_by_host |
36 | | -from cylc.flow.workflow_files import ( |
37 | | - ContactFileFields, |
38 | | - KeyType, |
39 | | - KeyOwner, |
40 | | - KeyInfo, |
41 | | - load_contact_file, |
42 | | - get_workflow_srv_dir |
43 | | -) |
44 | | - |
45 | | -API = 5 # cylc API version |
46 | | -MSG_TIMEOUT = "TIMEOUT" |
47 | | - |
48 | | - |
49 | | -def encode_(message): |
50 | | - """Convert the structure holding a message field from JSON to a string.""" |
51 | | - try: |
52 | | - return json.dumps(message) |
53 | | - except TypeError as exc: |
54 | | - return json.dumps({'errors': [{'message': str(exc)}]}) |
55 | | - |
56 | | - |
57 | | -def decode_(message): |
58 | | - """Convert an encoded message string to JSON with an added 'user' field.""" |
59 | | - msg = json.loads(message) |
60 | | - msg['user'] = getpass.getuser() # assume this is the user |
61 | | - return msg |
62 | | - |
63 | | - |
64 | | -def get_location(workflow: str) -> Tuple[str, int, int]: |
65 | | - """Extract host and port from a workflow's contact file. |
66 | | -
|
67 | | - NB: if it fails to load the workflow contact file, it will exit. |
68 | | -
|
69 | | - Args: |
70 | | - workflow: workflow ID |
71 | | - Returns: |
72 | | - Tuple (host name, port number, publish port number) |
73 | | - Raises: |
74 | | - WorkflowStopped: if the workflow is not running. |
75 | | - CylcVersionError: if target is a Cylc 7 (or earlier) workflow. |
76 | | - """ |
77 | | - try: |
78 | | - contact = load_contact_file(workflow) |
79 | | - except (IOError, ValueError, ServiceFileError): |
80 | | - # Contact file does not exist or corrupted, workflow should be dead |
81 | | - raise WorkflowStopped(workflow) |
82 | | - |
83 | | - host = contact[ContactFileFields.HOST] |
84 | | - host = get_fqdn_by_host(host) |
85 | | - port = int(contact[ContactFileFields.PORT]) |
86 | | - if ContactFileFields.PUBLISH_PORT in contact: |
87 | | - pub_port = int(contact[ContactFileFields.PUBLISH_PORT]) |
88 | | - else: |
89 | | - version = contact.get('CYLC_VERSION', None) |
90 | | - raise CylcVersionError(version=version) |
91 | | - return host, port, pub_port |
92 | | - |
93 | | - |
94 | | -class ZMQSocketBase: |
95 | | - """Initiate the ZMQ socket bind for specified pattern. |
96 | | -
|
97 | | - NOTE: Security to be provided via zmq.auth (see PR #3359). |
98 | | -
|
99 | | - Args: |
100 | | - pattern (enum): ZeroMQ message pattern (zmq.PATTERN). |
101 | | -
|
102 | | - context (object, optional): instantiated ZeroMQ context, defaults |
103 | | - to zmq.asyncio.Context(). |
104 | | -
|
105 | | - This class is designed to be inherited by REP Server (REQ/REP) |
106 | | - and by PUB Publisher (PUB/SUB), as the start-up logic is similar. |
107 | | -
|
108 | | -
|
109 | | - To tailor this class overwrite it's method on inheritance. |
110 | | -
|
111 | | - """ |
112 | | - |
113 | | - def __init__( |
114 | | - self, |
115 | | - pattern, |
116 | | - workflow: str, |
117 | | - bind: bool = False, |
118 | | - context: Optional[zmq.Context] = None, |
119 | | - ): |
120 | | - self.bind = bind |
121 | | - if context is None: |
122 | | - self.context: zmq.Context = zmq.asyncio.Context() |
123 | | - else: |
124 | | - self.context = context |
125 | | - self.pattern = pattern |
126 | | - self.workflow = workflow |
127 | | - self.host: Optional[str] = None |
128 | | - self.port: Optional[int] = None |
129 | | - self.socket: Optional[zmq.Socket] = None |
130 | | - self.loop: Optional[asyncio.AbstractEventLoop] = None |
131 | | - self.stopping = False |
132 | | - |
133 | | - def start(self, *args, **kwargs): |
134 | | - """Create the async loop, and bind socket.""" |
135 | | - # set asyncio loop |
136 | | - try: |
137 | | - self.loop = asyncio.get_running_loop() |
138 | | - except RuntimeError: |
139 | | - self.loop = asyncio.new_event_loop() |
140 | | - asyncio.set_event_loop(self.loop) |
141 | | - |
142 | | - if self.bind: |
143 | | - self._socket_bind(*args, **kwargs) |
144 | | - else: |
145 | | - self._socket_connect(*args, **kwargs) |
146 | | - |
147 | | - # initiate bespoke items |
148 | | - self._bespoke_start() |
149 | | - |
150 | | - # Keeping srv_prv_key_loc as optional arg so as to not break interface |
151 | | - def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None): |
152 | | - """Bind socket. |
153 | | -
|
154 | | - Will use a port range provided to select random ports. |
155 | | -
|
156 | | - """ |
157 | | - if srv_prv_key_loc is None: |
158 | | - # Create new KeyInfo object for the server private key |
159 | | - workflow_srv_dir = get_workflow_srv_dir(self.workflow) |
160 | | - srv_prv_key_info = KeyInfo( |
161 | | - KeyType.PRIVATE, |
162 | | - KeyOwner.SERVER, |
163 | | - workflow_srv_dir=workflow_srv_dir) |
164 | | - else: |
165 | | - srv_prv_key_info = KeyInfo( |
166 | | - KeyType.PRIVATE, |
167 | | - KeyOwner.SERVER, |
168 | | - full_key_path=srv_prv_key_loc) |
169 | | - |
170 | | - # create socket |
171 | | - self.socket = self.context.socket(self.pattern) |
172 | | - self._socket_options() |
173 | | - |
174 | | - try: |
175 | | - server_public_key, server_private_key = zmq.auth.load_certificate( |
176 | | - srv_prv_key_info.full_key_path) |
177 | | - except ValueError: |
178 | | - raise ServiceFileError( |
179 | | - f"Failed to find server's public " |
180 | | - f"key in " |
181 | | - f"{srv_prv_key_info.full_key_path}." |
182 | | - ) |
183 | | - except OSError: |
184 | | - raise ServiceFileError( |
185 | | - f"IO error opening server's private " |
186 | | - f"key from " |
187 | | - f"{srv_prv_key_info.full_key_path}." |
188 | | - ) |
189 | | - if server_private_key is None: # this can't be caught by exception |
190 | | - raise ServiceFileError( |
191 | | - f"Failed to find server's private " |
192 | | - f"key in " |
193 | | - f"{srv_prv_key_info.full_key_path}." |
194 | | - ) |
195 | | - self.socket.curve_publickey = server_public_key |
196 | | - self.socket.curve_secretkey = server_private_key |
197 | | - self.socket.curve_server = True |
198 | | - |
199 | | - try: |
200 | | - if min_port == max_port: |
201 | | - self.port = min_port |
202 | | - self.socket.bind(f'tcp://*:{min_port}') |
203 | | - else: |
204 | | - self.port = self.socket.bind_to_random_port( |
205 | | - 'tcp://*', min_port, max_port) |
206 | | - except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc: |
207 | | - raise CylcError(f'could not start Cylc ZMQ server: {exc}') |
208 | | - |
209 | | - # Keeping srv_public_key_loc as optional arg so as to not break interface |
210 | | - def _socket_connect(self, host, port, srv_public_key_loc=None): |
211 | | - """Connect socket to stub.""" |
212 | | - workflow_srv_dir = get_workflow_srv_dir(self.workflow) |
213 | | - if srv_public_key_loc is None: |
214 | | - # Create new KeyInfo object for the server public key |
215 | | - srv_pub_key_info = KeyInfo( |
216 | | - KeyType.PUBLIC, |
217 | | - KeyOwner.SERVER, |
218 | | - workflow_srv_dir=workflow_srv_dir) |
219 | | - |
220 | | - else: |
221 | | - srv_pub_key_info = KeyInfo( |
222 | | - KeyType.PUBLIC, |
223 | | - KeyOwner.SERVER, |
224 | | - full_key_path=srv_public_key_loc) |
225 | | - |
226 | | - self.host = host |
227 | | - self.port = port |
228 | | - self.socket = self.context.socket(self.pattern) |
229 | | - self._socket_options() |
230 | | - |
231 | | - client_priv_key_info = KeyInfo( |
232 | | - KeyType.PRIVATE, |
233 | | - KeyOwner.CLIENT, |
234 | | - workflow_srv_dir=workflow_srv_dir) |
235 | | - error_msg = "Failed to find user's private key, so cannot connect." |
236 | | - try: |
237 | | - client_public_key, client_priv_key = zmq.auth.load_certificate( |
238 | | - client_priv_key_info.full_key_path) |
239 | | - except (OSError, ValueError): |
240 | | - raise ClientError(error_msg) |
241 | | - if client_priv_key is None: # this can't be caught by exception |
242 | | - raise ClientError(error_msg) |
243 | | - self.socket.curve_publickey = client_public_key |
244 | | - self.socket.curve_secretkey = client_priv_key |
245 | | - |
246 | | - # A client can only connect to the server if it knows its public key, |
247 | | - # so we grab this from the location it was created on the filesystem: |
248 | | - try: |
249 | | - # 'load_certificate' will try to load both public & private keys |
250 | | - # from a provided file but will return None, not throw an error, |
251 | | - # for the latter item if not there (as for all public key files) |
252 | | - # so it is OK to use; there is no method to load only the |
253 | | - # public key. |
254 | | - server_public_key = zmq.auth.load_certificate( |
255 | | - srv_pub_key_info.full_key_path)[0] |
256 | | - self.socket.curve_serverkey = server_public_key |
257 | | - except (OSError, ValueError): # ValueError raised w/ no public key |
258 | | - raise ClientError( |
259 | | - "Failed to load the workflow's public key, so cannot connect.") |
260 | | - |
261 | | - self.socket.connect(f'tcp://{host}:{port}') |
262 | | - |
263 | | - def _socket_options(self): |
264 | | - """Set socket options. |
265 | | -
|
266 | | - i.e. self.socket.sndhwm |
267 | | - """ |
268 | | - self.socket.sndhwm = 10000 |
269 | | - |
270 | | - def _bespoke_start(self): |
271 | | - """Initiate bespoke items at start.""" |
272 | | - self.stopping = False |
273 | | - |
274 | | - def stop(self, stop_loop=True): |
275 | | - """Stop the server. |
276 | | -
|
277 | | - Args: |
278 | | - stop_loop (Boolean): Stop running IOLoop. |
279 | | -
|
280 | | - """ |
281 | | - self._bespoke_stop() |
282 | | - if stop_loop and self.loop and self.loop.is_running(): |
283 | | - self.loop.stop() |
284 | | - if self.socket and not self.socket.closed: |
285 | | - self.socket.close() |
286 | | - LOG.debug('...stopped') |
287 | | - |
288 | | - def _bespoke_stop(self): |
289 | | - """Bespoke stop items.""" |
290 | | - LOG.debug('stopping zmq socket...') |
291 | | - self.stopping = True |
| 26 | +# Cylc API version. |
| 27 | +# This is the Cylc protocol version number that determines whether a client can |
| 28 | +# communicate with a server. This should be changed when breaking changes are |
| 29 | +# made for which backwards compatibility can not be provided. |
| 30 | +API = 5 |
0 commit comments