| 
13 | 13 | #  | 
14 | 14 | # You should have received a copy of the GNU General Public License  | 
15 | 15 | # along with this program.  If not, see <http://www.gnu.org/licenses/>.  | 
16 |  | -"""Package for network interfaces to Cylc scheduler objects."""  | 
17 | 16 | 
 
  | 
18 |  | -import asyncio  | 
19 |  | -import getpass  | 
20 |  | -import json  | 
21 |  | -from typing import Optional, Tuple  | 
 | 17 | +"""Cylc networking code.  | 
22 | 18 | 
  | 
23 |  | -import zmq  | 
24 |  | -import zmq.asyncio  | 
25 |  | -import zmq.auth  | 
 | 19 | +Contains:  | 
 | 20 | +* Server code (hosted by the scheduler process).  | 
 | 21 | +* Client implementations (used to communicate with the scheduler).  | 
 | 22 | +* Workflow scanning logic.  | 
 | 23 | +* Schema and interface definitions.  | 
 | 24 | +"""  | 
26 | 25 | 
 
  | 
27 |  | -from cylc.flow import LOG  | 
28 |  | -from cylc.flow.exceptions import (  | 
29 |  | -    ClientError,  | 
30 |  | -    CylcError,  | 
31 |  | -    CylcVersionError,  | 
32 |  | -    ServiceFileError,  | 
33 |  | -    WorkflowStopped  | 
34 |  | -)  | 
35 |  | -from cylc.flow.hostuserutil import get_fqdn_by_host  | 
36 |  | -from cylc.flow.workflow_files import (  | 
37 |  | -    ContactFileFields,  | 
38 |  | -    KeyType,  | 
39 |  | -    KeyOwner,  | 
40 |  | -    KeyInfo,  | 
41 |  | -    load_contact_file,  | 
42 |  | -    get_workflow_srv_dir  | 
43 |  | -)  | 
44 |  | - | 
45 |  | -API = 5  # cylc API version  | 
46 |  | -MSG_TIMEOUT = "TIMEOUT"  | 
47 |  | - | 
48 |  | - | 
49 |  | -def encode_(message):  | 
50 |  | -    """Convert the structure holding a message field from JSON to a string."""  | 
51 |  | -    try:  | 
52 |  | -        return json.dumps(message)  | 
53 |  | -    except TypeError as exc:  | 
54 |  | -        return json.dumps({'errors': [{'message': str(exc)}]})  | 
55 |  | - | 
56 |  | - | 
57 |  | -def decode_(message):  | 
58 |  | -    """Convert an encoded message string to JSON with an added 'user' field."""  | 
59 |  | -    msg = json.loads(message)  | 
60 |  | -    msg['user'] = getpass.getuser()  # assume this is the user  | 
61 |  | -    return msg  | 
62 |  | - | 
63 |  | - | 
64 |  | -def get_location(workflow: str) -> Tuple[str, int, int]:  | 
65 |  | -    """Extract host and port from a workflow's contact file.  | 
66 |  | -
  | 
67 |  | -    NB: if it fails to load the workflow contact file, it will exit.  | 
68 |  | -
  | 
69 |  | -    Args:  | 
70 |  | -        workflow: workflow ID  | 
71 |  | -    Returns:  | 
72 |  | -        Tuple (host name, port number, publish port number)  | 
73 |  | -    Raises:  | 
74 |  | -        WorkflowStopped: if the workflow is not running.  | 
75 |  | -        CylcVersionError: if target is a Cylc 7 (or earlier) workflow.  | 
76 |  | -    """  | 
77 |  | -    try:  | 
78 |  | -        contact = load_contact_file(workflow)  | 
79 |  | -    except (IOError, ValueError, ServiceFileError):  | 
80 |  | -        # Contact file does not exist or corrupted, workflow should be dead  | 
81 |  | -        raise WorkflowStopped(workflow)  | 
82 |  | - | 
83 |  | -    host = contact[ContactFileFields.HOST]  | 
84 |  | -    host = get_fqdn_by_host(host)  | 
85 |  | -    port = int(contact[ContactFileFields.PORT])  | 
86 |  | -    if ContactFileFields.PUBLISH_PORT in contact:  | 
87 |  | -        pub_port = int(contact[ContactFileFields.PUBLISH_PORT])  | 
88 |  | -    else:  | 
89 |  | -        version = contact.get('CYLC_VERSION', None)  | 
90 |  | -        raise CylcVersionError(version=version)  | 
91 |  | -    return host, port, pub_port  | 
92 |  | - | 
93 |  | - | 
94 |  | -class ZMQSocketBase:  | 
95 |  | -    """Initiate the ZMQ socket bind for specified pattern.  | 
96 |  | -
  | 
97 |  | -    NOTE: Security to be provided via zmq.auth (see PR #3359).  | 
98 |  | -
  | 
99 |  | -    Args:  | 
100 |  | -        pattern (enum): ZeroMQ message pattern (zmq.PATTERN).  | 
101 |  | -
  | 
102 |  | -        context (object, optional): instantiated ZeroMQ context, defaults  | 
103 |  | -            to zmq.asyncio.Context().  | 
104 |  | -
  | 
105 |  | -    This class is designed to be inherited by REP Server (REQ/REP)  | 
106 |  | -    and by PUB Publisher (PUB/SUB), as the start-up logic is similar.  | 
107 |  | -
  | 
108 |  | -
  | 
109 |  | -    To tailor this class overwrite it's method on inheritance.  | 
110 |  | -
  | 
111 |  | -    """  | 
112 |  | - | 
113 |  | -    def __init__(  | 
114 |  | -        self,  | 
115 |  | -        pattern,  | 
116 |  | -        workflow: str,  | 
117 |  | -        bind: bool = False,  | 
118 |  | -        context: Optional[zmq.Context] = None,  | 
119 |  | -    ):  | 
120 |  | -        self.bind = bind  | 
121 |  | -        if context is None:  | 
122 |  | -            self.context: zmq.Context = zmq.asyncio.Context()  | 
123 |  | -        else:  | 
124 |  | -            self.context = context  | 
125 |  | -        self.pattern = pattern  | 
126 |  | -        self.workflow = workflow  | 
127 |  | -        self.host: Optional[str] = None  | 
128 |  | -        self.port: Optional[int] = None  | 
129 |  | -        self.socket: Optional[zmq.Socket] = None  | 
130 |  | -        self.loop: Optional[asyncio.AbstractEventLoop] = None  | 
131 |  | -        self.stopping = False  | 
132 |  | - | 
133 |  | -    def start(self, *args, **kwargs):  | 
134 |  | -        """Create the async loop, and bind socket."""  | 
135 |  | -        # set asyncio loop  | 
136 |  | -        try:  | 
137 |  | -            self.loop = asyncio.get_running_loop()  | 
138 |  | -        except RuntimeError:  | 
139 |  | -            self.loop = asyncio.new_event_loop()  | 
140 |  | -            asyncio.set_event_loop(self.loop)  | 
141 |  | - | 
142 |  | -        if self.bind:  | 
143 |  | -            self._socket_bind(*args, **kwargs)  | 
144 |  | -        else:  | 
145 |  | -            self._socket_connect(*args, **kwargs)  | 
146 |  | - | 
147 |  | -        # initiate bespoke items  | 
148 |  | -        self._bespoke_start()  | 
149 |  | - | 
150 |  | -    # Keeping srv_prv_key_loc as optional arg so as to not break interface  | 
151 |  | -    def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None):  | 
152 |  | -        """Bind socket.  | 
153 |  | -
  | 
154 |  | -        Will use a port range provided to select random ports.  | 
155 |  | -
  | 
156 |  | -        """  | 
157 |  | -        if srv_prv_key_loc is None:  | 
158 |  | -            # Create new KeyInfo object for the server private key  | 
159 |  | -            workflow_srv_dir = get_workflow_srv_dir(self.workflow)  | 
160 |  | -            srv_prv_key_info = KeyInfo(  | 
161 |  | -                KeyType.PRIVATE,  | 
162 |  | -                KeyOwner.SERVER,  | 
163 |  | -                workflow_srv_dir=workflow_srv_dir)  | 
164 |  | -        else:  | 
165 |  | -            srv_prv_key_info = KeyInfo(  | 
166 |  | -                KeyType.PRIVATE,  | 
167 |  | -                KeyOwner.SERVER,  | 
168 |  | -                full_key_path=srv_prv_key_loc)  | 
169 |  | - | 
170 |  | -        # create socket  | 
171 |  | -        self.socket = self.context.socket(self.pattern)  | 
172 |  | -        self._socket_options()  | 
173 |  | - | 
174 |  | -        try:  | 
175 |  | -            server_public_key, server_private_key = zmq.auth.load_certificate(  | 
176 |  | -                srv_prv_key_info.full_key_path)  | 
177 |  | -        except ValueError:  | 
178 |  | -            raise ServiceFileError(  | 
179 |  | -                f"Failed to find server's public "  | 
180 |  | -                f"key in "  | 
181 |  | -                f"{srv_prv_key_info.full_key_path}."  | 
182 |  | -            )  | 
183 |  | -        except OSError:  | 
184 |  | -            raise ServiceFileError(  | 
185 |  | -                f"IO error opening server's private "  | 
186 |  | -                f"key from "  | 
187 |  | -                f"{srv_prv_key_info.full_key_path}."  | 
188 |  | -            )  | 
189 |  | -        if server_private_key is None:  # this can't be caught by exception  | 
190 |  | -            raise ServiceFileError(  | 
191 |  | -                f"Failed to find server's private "  | 
192 |  | -                f"key in "  | 
193 |  | -                f"{srv_prv_key_info.full_key_path}."  | 
194 |  | -            )  | 
195 |  | -        self.socket.curve_publickey = server_public_key  | 
196 |  | -        self.socket.curve_secretkey = server_private_key  | 
197 |  | -        self.socket.curve_server = True  | 
198 |  | - | 
199 |  | -        try:  | 
200 |  | -            if min_port == max_port:  | 
201 |  | -                self.port = min_port  | 
202 |  | -                self.socket.bind(f'tcp://*:{min_port}')  | 
203 |  | -            else:  | 
204 |  | -                self.port = self.socket.bind_to_random_port(  | 
205 |  | -                    'tcp://*', min_port, max_port)  | 
206 |  | -        except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc:  | 
207 |  | -            raise CylcError(f'could not start Cylc ZMQ server: {exc}')  | 
208 |  | - | 
209 |  | -    # Keeping srv_public_key_loc as optional arg so as to not break interface  | 
210 |  | -    def _socket_connect(self, host, port, srv_public_key_loc=None):  | 
211 |  | -        """Connect socket to stub."""  | 
212 |  | -        workflow_srv_dir = get_workflow_srv_dir(self.workflow)  | 
213 |  | -        if srv_public_key_loc is None:  | 
214 |  | -            # Create new KeyInfo object for the server public key  | 
215 |  | -            srv_pub_key_info = KeyInfo(  | 
216 |  | -                KeyType.PUBLIC,  | 
217 |  | -                KeyOwner.SERVER,  | 
218 |  | -                workflow_srv_dir=workflow_srv_dir)  | 
219 |  | - | 
220 |  | -        else:  | 
221 |  | -            srv_pub_key_info = KeyInfo(  | 
222 |  | -                KeyType.PUBLIC,  | 
223 |  | -                KeyOwner.SERVER,  | 
224 |  | -                full_key_path=srv_public_key_loc)  | 
225 |  | - | 
226 |  | -        self.host = host  | 
227 |  | -        self.port = port  | 
228 |  | -        self.socket = self.context.socket(self.pattern)  | 
229 |  | -        self._socket_options()  | 
230 |  | - | 
231 |  | -        client_priv_key_info = KeyInfo(  | 
232 |  | -            KeyType.PRIVATE,  | 
233 |  | -            KeyOwner.CLIENT,  | 
234 |  | -            workflow_srv_dir=workflow_srv_dir)  | 
235 |  | -        error_msg = "Failed to find user's private key, so cannot connect."  | 
236 |  | -        try:  | 
237 |  | -            client_public_key, client_priv_key = zmq.auth.load_certificate(  | 
238 |  | -                client_priv_key_info.full_key_path)  | 
239 |  | -        except (OSError, ValueError):  | 
240 |  | -            raise ClientError(error_msg)  | 
241 |  | -        if client_priv_key is None:  # this can't be caught by exception  | 
242 |  | -            raise ClientError(error_msg)  | 
243 |  | -        self.socket.curve_publickey = client_public_key  | 
244 |  | -        self.socket.curve_secretkey = client_priv_key  | 
245 |  | - | 
246 |  | -        # A client can only connect to the server if it knows its public key,  | 
247 |  | -        # so we grab this from the location it was created on the filesystem:  | 
248 |  | -        try:  | 
249 |  | -            # 'load_certificate' will try to load both public & private keys  | 
250 |  | -            # from a provided file but will return None, not throw an error,  | 
251 |  | -            # for the latter item if not there (as for all public key files)  | 
252 |  | -            # so it is OK to use; there is no method to load only the  | 
253 |  | -            # public key.  | 
254 |  | -            server_public_key = zmq.auth.load_certificate(  | 
255 |  | -                srv_pub_key_info.full_key_path)[0]  | 
256 |  | -            self.socket.curve_serverkey = server_public_key  | 
257 |  | -        except (OSError, ValueError):  # ValueError raised w/ no public key  | 
258 |  | -            raise ClientError(  | 
259 |  | -                "Failed to load the workflow's public key, so cannot connect.")  | 
260 |  | - | 
261 |  | -        self.socket.connect(f'tcp://{host}:{port}')  | 
262 |  | - | 
263 |  | -    def _socket_options(self):  | 
264 |  | -        """Set socket options.  | 
265 |  | -
  | 
266 |  | -        i.e. self.socket.sndhwm  | 
267 |  | -        """  | 
268 |  | -        self.socket.sndhwm = 10000  | 
269 |  | - | 
270 |  | -    def _bespoke_start(self):  | 
271 |  | -        """Initiate bespoke items at start."""  | 
272 |  | -        self.stopping = False  | 
273 |  | - | 
274 |  | -    def stop(self, stop_loop=True):  | 
275 |  | -        """Stop the server.  | 
276 |  | -
  | 
277 |  | -        Args:  | 
278 |  | -            stop_loop (Boolean): Stop running IOLoop.  | 
279 |  | -
  | 
280 |  | -        """  | 
281 |  | -        self._bespoke_stop()  | 
282 |  | -        if stop_loop and self.loop and self.loop.is_running():  | 
283 |  | -            self.loop.stop()  | 
284 |  | -        if self.socket and not self.socket.closed:  | 
285 |  | -            self.socket.close()  | 
286 |  | -        LOG.debug('...stopped')  | 
287 |  | - | 
288 |  | -    def _bespoke_stop(self):  | 
289 |  | -        """Bespoke stop items."""  | 
290 |  | -        LOG.debug('stopping zmq socket...')  | 
291 |  | -        self.stopping = True  | 
 | 26 | +# Cylc API version.  | 
 | 27 | +# This is the Cylc protocol version number that determines whether a client can  | 
 | 28 | +# communicate with a server. This should be changed when breaking changes are  | 
 | 29 | +# made for which backwards compatibility can not be provided.  | 
 | 30 | +API = 5  | 
0 commit comments