Skip to content

Commit c908539

Browse files
network: move code out of __init__ file
* Some code in the __init__ file was creating a circular dependency issue. * Move this code into its own modules; this may have efficiency benefits for other network modules when the code in __init__ is not required.
1 parent 6ab1ede commit c908539

File tree

11 files changed

+348
-292
lines changed

11 files changed

+348
-292
lines changed

cylc/flow/network/__init__.py

Lines changed: 12 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -13,279 +13,18 @@
1313
#
1414
# You should have received a copy of the GNU General Public License
1515
# along with this program. If not, see <http://www.gnu.org/licenses/>.
16-
"""Package for network interfaces to Cylc scheduler objects."""
1716

18-
import asyncio
19-
import getpass
20-
import json
21-
from typing import Optional, Tuple
17+
"""Cylc networking code.
2218
23-
import zmq
24-
import zmq.asyncio
25-
import zmq.auth
19+
Contains:
20+
* Server code (hosted by the scheduler process).
21+
* Client implementations (used to communicate with the scheduler).
22+
* Workflow scanning logic.
23+
* Schema and interface definitions.
24+
"""
2625

27-
from cylc.flow import LOG
28-
from cylc.flow.exceptions import (
29-
ClientError,
30-
CylcError,
31-
CylcVersionError,
32-
ServiceFileError,
33-
WorkflowStopped
34-
)
35-
from cylc.flow.hostuserutil import get_fqdn_by_host
36-
from cylc.flow.workflow_files import (
37-
ContactFileFields,
38-
KeyType,
39-
KeyOwner,
40-
KeyInfo,
41-
load_contact_file,
42-
get_workflow_srv_dir
43-
)
44-
45-
API = 5 # cylc API version
46-
MSG_TIMEOUT = "TIMEOUT"
47-
48-
49-
def encode_(message):
50-
"""Convert the structure holding a message field from JSON to a string."""
51-
try:
52-
return json.dumps(message)
53-
except TypeError as exc:
54-
return json.dumps({'errors': [{'message': str(exc)}]})
55-
56-
57-
def decode_(message):
58-
"""Convert an encoded message string to JSON with an added 'user' field."""
59-
msg = json.loads(message)
60-
msg['user'] = getpass.getuser() # assume this is the user
61-
return msg
62-
63-
64-
def get_location(workflow: str) -> Tuple[str, int, int]:
65-
"""Extract host and port from a workflow's contact file.
66-
67-
NB: if it fails to load the workflow contact file, it raises WorkflowStopped.
68-
69-
Args:
70-
workflow: workflow ID
71-
Returns:
72-
Tuple (host name, port number, publish port number)
73-
Raises:
74-
WorkflowStopped: if the workflow is not running.
75-
CylcVersionError: if target is a Cylc 7 (or earlier) workflow.
76-
"""
77-
try:
78-
contact = load_contact_file(workflow)
79-
except (IOError, ValueError, ServiceFileError):
80-
# Contact file does not exist or corrupted, workflow should be dead
81-
raise WorkflowStopped(workflow)
82-
83-
host = contact[ContactFileFields.HOST]
84-
host = get_fqdn_by_host(host)
85-
port = int(contact[ContactFileFields.PORT])
86-
if ContactFileFields.PUBLISH_PORT in contact:
87-
pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
88-
else:
89-
version = contact.get('CYLC_VERSION', None)
90-
raise CylcVersionError(version=version)
91-
return host, port, pub_port
92-
93-
94-
class ZMQSocketBase:
95-
"""Initiate the ZMQ socket bind for specified pattern.
96-
97-
NOTE: Security to be provided via zmq.auth (see PR #3359).
98-
99-
Args:
100-
pattern (enum): ZeroMQ message pattern (zmq.PATTERN).
101-
102-
context (object, optional): instantiated ZeroMQ context, defaults
103-
to zmq.asyncio.Context().
104-
105-
This class is designed to be inherited by REP Server (REQ/REP)
106-
and by PUB Publisher (PUB/SUB), as the start-up logic is similar.
107-
108-
109-
To tailor this class, override its methods on inheritance.
110-
111-
"""
112-
113-
def __init__(
114-
self,
115-
pattern,
116-
workflow: str,
117-
bind: bool = False,
118-
context: Optional[zmq.Context] = None,
119-
):
120-
self.bind = bind
121-
if context is None:
122-
self.context: zmq.Context = zmq.asyncio.Context()
123-
else:
124-
self.context = context
125-
self.pattern = pattern
126-
self.workflow = workflow
127-
self.host: Optional[str] = None
128-
self.port: Optional[int] = None
129-
self.socket: Optional[zmq.Socket] = None
130-
self.loop: Optional[asyncio.AbstractEventLoop] = None
131-
self.stopping = False
132-
133-
def start(self, *args, **kwargs):
134-
"""Create the async loop, and bind socket."""
135-
# set asyncio loop
136-
try:
137-
self.loop = asyncio.get_running_loop()
138-
except RuntimeError:
139-
self.loop = asyncio.new_event_loop()
140-
asyncio.set_event_loop(self.loop)
141-
142-
if self.bind:
143-
self._socket_bind(*args, **kwargs)
144-
else:
145-
self._socket_connect(*args, **kwargs)
146-
147-
# initiate bespoke items
148-
self._bespoke_start()
149-
150-
# Keeping srv_prv_key_loc as optional arg so as to not break interface
151-
def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None):
152-
"""Bind socket.
153-
154-
Will use a port range provided to select random ports.
155-
156-
"""
157-
if srv_prv_key_loc is None:
158-
# Create new KeyInfo object for the server private key
159-
workflow_srv_dir = get_workflow_srv_dir(self.workflow)
160-
srv_prv_key_info = KeyInfo(
161-
KeyType.PRIVATE,
162-
KeyOwner.SERVER,
163-
workflow_srv_dir=workflow_srv_dir)
164-
else:
165-
srv_prv_key_info = KeyInfo(
166-
KeyType.PRIVATE,
167-
KeyOwner.SERVER,
168-
full_key_path=srv_prv_key_loc)
169-
170-
# create socket
171-
self.socket = self.context.socket(self.pattern)
172-
self._socket_options()
173-
174-
try:
175-
server_public_key, server_private_key = zmq.auth.load_certificate(
176-
srv_prv_key_info.full_key_path)
177-
except ValueError:
178-
raise ServiceFileError(
179-
f"Failed to find server's public "
180-
f"key in "
181-
f"{srv_prv_key_info.full_key_path}."
182-
)
183-
except OSError:
184-
raise ServiceFileError(
185-
f"IO error opening server's private "
186-
f"key from "
187-
f"{srv_prv_key_info.full_key_path}."
188-
)
189-
if server_private_key is None: # this can't be caught by exception
190-
raise ServiceFileError(
191-
f"Failed to find server's private "
192-
f"key in "
193-
f"{srv_prv_key_info.full_key_path}."
194-
)
195-
self.socket.curve_publickey = server_public_key
196-
self.socket.curve_secretkey = server_private_key
197-
self.socket.curve_server = True
198-
199-
try:
200-
if min_port == max_port:
201-
self.port = min_port
202-
self.socket.bind(f'tcp://*:{min_port}')
203-
else:
204-
self.port = self.socket.bind_to_random_port(
205-
'tcp://*', min_port, max_port)
206-
except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:
207-
raise CylcError(f'could not start Cylc ZMQ server: {exc}')
208-
209-
# Keeping srv_public_key_loc as optional arg so as to not break interface
210-
def _socket_connect(self, host, port, srv_public_key_loc=None):
211-
"""Connect socket to stub."""
212-
workflow_srv_dir = get_workflow_srv_dir(self.workflow)
213-
if srv_public_key_loc is None:
214-
# Create new KeyInfo object for the server public key
215-
srv_pub_key_info = KeyInfo(
216-
KeyType.PUBLIC,
217-
KeyOwner.SERVER,
218-
workflow_srv_dir=workflow_srv_dir)
219-
220-
else:
221-
srv_pub_key_info = KeyInfo(
222-
KeyType.PUBLIC,
223-
KeyOwner.SERVER,
224-
full_key_path=srv_public_key_loc)
225-
226-
self.host = host
227-
self.port = port
228-
self.socket = self.context.socket(self.pattern)
229-
self._socket_options()
230-
231-
client_priv_key_info = KeyInfo(
232-
KeyType.PRIVATE,
233-
KeyOwner.CLIENT,
234-
workflow_srv_dir=workflow_srv_dir)
235-
error_msg = "Failed to find user's private key, so cannot connect."
236-
try:
237-
client_public_key, client_priv_key = zmq.auth.load_certificate(
238-
client_priv_key_info.full_key_path)
239-
except (OSError, ValueError):
240-
raise ClientError(error_msg)
241-
if client_priv_key is None: # this can't be caught by exception
242-
raise ClientError(error_msg)
243-
self.socket.curve_publickey = client_public_key
244-
self.socket.curve_secretkey = client_priv_key
245-
246-
# A client can only connect to the server if it knows its public key,
247-
# so we grab this from the location it was created on the filesystem:
248-
try:
249-
# 'load_certificate' will try to load both public & private keys
250-
# from a provided file but will return None, not throw an error,
251-
# for the latter item if not there (as for all public key files)
252-
# so it is OK to use; there is no method to load only the
253-
# public key.
254-
server_public_key = zmq.auth.load_certificate(
255-
srv_pub_key_info.full_key_path)[0]
256-
self.socket.curve_serverkey = server_public_key
257-
except (OSError, ValueError): # ValueError raised w/ no public key
258-
raise ClientError(
259-
"Failed to load the workflow's public key, so cannot connect.")
260-
261-
self.socket.connect(f'tcp://{host}:{port}')
262-
263-
def _socket_options(self):
264-
"""Set socket options.
265-
266-
i.e. self.socket.sndhwm
267-
"""
268-
self.socket.sndhwm = 10000
269-
270-
def _bespoke_start(self):
271-
"""Initiate bespoke items at start."""
272-
self.stopping = False
273-
274-
def stop(self, stop_loop=True):
275-
"""Stop the server.
276-
277-
Args:
278-
stop_loop (Boolean): Stop running IOLoop.
279-
280-
"""
281-
self._bespoke_stop()
282-
if stop_loop and self.loop and self.loop.is_running():
283-
self.loop.stop()
284-
if self.socket and not self.socket.closed:
285-
self.socket.close()
286-
LOG.debug('...stopped')
287-
288-
def _bespoke_stop(self):
289-
"""Bespoke stop items."""
290-
LOG.debug('stopping zmq socket...')
291-
self.stopping = True
26+
# Cylc API version.
27+
# This is the Cylc protocol version number that determines whether a client can
28+
# communicate with a server. This should be changed when breaking changes are
29+
# made for which backwards compatibility can not be provided.
30+
API = 5

0 commit comments

Comments
 (0)