Skip to content

Commit 3394591

Browse files
Backport PR #879 on branch 7.x (Reconcile connection information) (#881)
Co-authored-by: Kevin Bates <[email protected]>
1 parent 763ff5f commit 3394591

File tree

3 files changed

+132
-25
lines changed

3 files changed

+132
-25
lines changed

jupyter_client/connect.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -158,21 +158,9 @@ def write_connection_file(
158158
cfg["signature_scheme"] = signature_scheme
159159
cfg["kernel_name"] = kernel_name
160160

161-
# Prevent over-writing a file that has already been written with the same
162-
# info. This is to prevent a race condition where the process has
163-
# already been launched but has not yet read the connection file.
164-
if os.path.exists(fname):
165-
with open(fname) as f:
166-
try:
167-
data = json.load(f)
168-
if data == cfg:
169-
return fname, cfg
170-
except Exception:
171-
pass
172-
173161
# Only ever write this file as user read/writeable
174162
# This would otherwise introduce a vulnerability as a file has secrets
175-
# which would let others execute arbitrarily code as you
163+
# which would let others execute arbitrary code as you
176164
with secure_write(fname) as f:
177165
f.write(json.dumps(cfg, indent=2))
178166

@@ -579,18 +567,70 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None:
579567
if "signature_scheme" in info:
580568
self.session.signature_scheme = info["signature_scheme"]
581569

582-
def _force_connection_info(self, info: KernelConnectionInfo) -> None:
583-
"""Unconditionally loads connection info from a dict containing connection info.
570+
def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
571+
"""Reconciles the connection information returned from the Provisioner.
584572
585-
Overwrites connection info-based attributes, regardless of their current values
586-
and writes this information to the connection file.
573+
Because some provisioners (like derivations of LocalProvisioner) may have already
574+
written the connection file, this method needs to ensure that, if the connection
575+
file exists, its contents match that of what was returned by the provisioner. If
576+
the file does exist and its contents do not match, a ValueError is raised.
577+
578+
If the file does not exist, the connection information in 'info' is loaded into the
579+
KernelManager and written to the file.
587580
"""
588-
# Reset current ports to 0 and indicate file has not been written to enable override
589-
self._connection_file_written = False
590-
for name in port_names:
591-
setattr(self, name, 0)
592-
self.load_connection_info(info)
593-
self.write_connection_file()
581+
# Prevent over-writing a file that has already been written with the same
582+
# info. This is to prevent a race condition where the process has
583+
# already been launched but has not yet read the connection file - as is
584+
# the case with LocalProvisioners.
585+
file_exists: bool = False
586+
if os.path.exists(self.connection_file):
587+
with open(self.connection_file) as f:
588+
file_info = json.load(f)
589+
# Prior to the following comparison, we need to adjust the value of "key" to
590+
# be bytes, otherwise the comparison below will fail.
591+
file_info["key"] = file_info["key"].encode()
592+
if not self._equal_connections(info, file_info):
593+
raise ValueError(
594+
"Connection file already exists and does not match "
595+
"the expected values returned from provisioner!"
596+
)
597+
file_exists = True
598+
599+
if not file_exists:
600+
# Load the connection info and write out file. Note, this does not necessarily
601+
# overwrite non-zero port values, so we'll validate afterward.
602+
self.load_connection_info(info)
603+
self.write_connection_file()
604+
605+
# Ensure what is in KernelManager is what we expect. This will catch issues if the file
606+
# already existed, yet it's contents differed from the KernelManager's (and provisioner).
607+
km_info = self.get_connection_info()
608+
if not self._equal_connections(info, km_info):
609+
raise ValueError(
610+
"KernelManager's connection information already exists and does not match "
611+
"the expected values returned from provisioner!"
612+
)
613+
614+
@staticmethod
615+
def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
616+
"""Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""
617+
618+
pertinent_keys = [
619+
"key",
620+
"ip",
621+
"stdin_port",
622+
"iopub_port",
623+
"shell_port",
624+
"control_port",
625+
"hb_port",
626+
"transport",
627+
"signature_scheme",
628+
]
629+
630+
for key in pertinent_keys:
631+
if conn1.get(key) != conn2.get(key):
632+
return False
633+
return True
594634

595635
# --------------------------------------------------------------------------
596636
# Creating connected sockets

jupyter_client/manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ async def _async_launch_kernel(self, kernel_cmd: t.List[str], **kw: t.Any) -> No
310310
assert self.provisioner is not None
311311
connection_info = await self.provisioner.launch_kernel(kernel_cmd, **kw)
312312
assert self.provisioner.has_process
313-
# Provisioner provides the connection information. Load into kernel manager and write file.
314-
self._force_connection_info(connection_info)
313+
# Provisioner provides the connection information. Load into kernel manager
314+
# and write the connection file, if not already done.
315+
self._reconcile_connection_info(connection_info)
315316

316317
_launch_kernel = run_sync(_async_launch_kernel)
317318

jupyter_client/tests/test_connect.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import os
66
from tempfile import TemporaryDirectory
77

8+
import pytest
89
from jupyter_core.application import JupyterApp
910
from jupyter_core.paths import jupyter_runtime_dir
1011

1112
from jupyter_client import connect
1213
from jupyter_client import KernelClient
14+
from jupyter_client import KernelManager
1315
from jupyter_client.consoleapp import JupyterConsoleApp
1416
from jupyter_client.session import Session
1517

@@ -235,3 +237,67 @@ def test_mixin_cleanup_random_ports():
235237
assert not os.path.exists(filename)
236238
for name in dc._random_port_names:
237239
assert getattr(dc, name) == 0
240+
241+
242+
param_values = [
243+
(True, True, None),
244+
(True, False, ValueError),
245+
(False, True, None),
246+
(False, False, ValueError),
247+
]
248+
249+
250+
@pytest.mark.parametrize("file_exists, km_matches, expected_exception", param_values)
251+
def test_reconcile_connection_info(file_exists, km_matches, expected_exception):
252+
253+
expected_info = sample_info
254+
mismatched_info = sample_info.copy()
255+
mismatched_info["key"] = b"def456"
256+
mismatched_info["shell_port"] = expected_info["shell_port"] + 42
257+
mismatched_info["control_port"] = expected_info["control_port"] + 42
258+
259+
with TemporaryDirectory() as connection_dir:
260+
261+
cf = os.path.join(connection_dir, "kernel.json")
262+
km = KernelManager()
263+
km.connection_file = cf
264+
265+
if file_exists:
266+
_, info = connect.write_connection_file(cf, **expected_info)
267+
info["key"] = info["key"].encode() # set 'key' back to bytes
268+
269+
if km_matches:
270+
# Let this be the case where the connection file exists, and the KM has matching
271+
# values prior to reconciliation. This is the LocalProvisioner case.
272+
provisioner_info = info
273+
km.load_connection_info(provisioner_info)
274+
else:
275+
# Let this be the case where the connection file exists, the KM has no values
276+
# prior to reconciliation, but the provisioner has returned different values
277+
# and a ValueError is expected.
278+
provisioner_info = mismatched_info
279+
else: # connection file does not exist
280+
if km_matches:
281+
# Let this be the case where the connection file does not exist, NOR does the KM
282+
# have any values of its own and reconciliation sets those values. This is the
283+
# non-LocalProvisioner case.
284+
provisioner_info = expected_info
285+
else:
286+
# Let this be the case where the connection file does not exist, yet the KM
287+
# has values that do not match those returned from the provisioner and a
288+
# ValueError is expected.
289+
km.load_connection_info(expected_info)
290+
provisioner_info = mismatched_info
291+
292+
if expected_exception is None:
293+
km._reconcile_connection_info(provisioner_info)
294+
km_info = km.get_connection_info()
295+
assert km._equal_connections(km_info, provisioner_info)
296+
else:
297+
with pytest.raises(expected_exception) as ex:
298+
km._reconcile_connection_info(provisioner_info)
299+
if file_exists:
300+
assert "Connection file already exists" in str(ex.value)
301+
else:
302+
assert "KernelManager's connection information already exists" in str(ex.value)
303+
assert km._equal_connections(km.get_connection_info(), provisioner_info) is False

0 commit comments

Comments
 (0)