Skip to content

Commit 8e289b6

Browse files
committed
add load_connection_info
for loading connection_info from a dict, not just files.
1 parent 112095f commit 8e289b6

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

jupyter_client/connect.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from traitlets.config import LoggingConfigurable
2424
from .localinterfaces import localhost
2525
from ipython_genutils.path import filefind
26-
from ipython_genutils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
27-
string_types)
26+
from ipython_genutils.py3compat import (
27+
bytes_to_str, cast_bytes, cast_bytes_py2, string_types,
28+
)
2829
from traitlets import (
2930
Bool, Integer, Unicode, CaselessStrEnum, Instance, Type,
3031
)
@@ -411,23 +412,46 @@ def write_connection_file(self):
411412

412413
self._connection_file_written = True
413414

414-
def load_connection_file(self):
415-
"""Load connection info from JSON dict in self.connection_file."""
416-
self.log.debug(u"Loading connection file %s", self.connection_file)
417-
with open(self.connection_file) as f:
418-
cfg = json.load(f)
419-
self.transport = cfg.get('transport', self.transport)
420-
self.ip = cfg.get('ip', self._ip_default())
415+
def load_connection_file(self, connection_file=None):
416+
"""Load connection info from JSON dict in self.connection_file.
417+
418+
Parameters
419+
----------
420+
connection_file: unicode, optional
421+
Path to connection file to load.
422+
If unspecified, use self.connection_file
423+
"""
424+
if connection_file is None:
425+
connection_file = self.connection_file
426+
self.log.debug(u"Loading connection file %s", connection_file)
427+
with open(connection_file) as f:
428+
info = json.load(f)
429+
self.load_connection_info(info)
430+
431+
def load_connection_info(self, info):
432+
"""Load connection info from a dict containing connection info.
433+
434+
Typically this data comes from a connection file
435+
and is called by load_connection_file.
436+
437+
Parameters
438+
----------
439+
info: dict
440+
Dictionary containing connection_info.
441+
See the connection_file spec for details.
442+
"""
443+
self.transport = info.get('transport', self.transport)
444+
self.ip = info.get('ip', self._ip_default())
421445

422446
for name in port_names:
423-
if getattr(self, name) == 0 and name in cfg:
447+
if getattr(self, name) == 0 and name in info:
424448
# not overridden by config or cl_args
425-
setattr(self, name, cfg[name])
449+
setattr(self, name, info[name])
426450

427-
if 'key' in cfg:
428-
self.session.key = str_to_bytes(cfg['key'])
429-
if 'signature_scheme' in cfg:
430-
self.session.signature_scheme = cfg['signature_scheme']
451+
if 'key' in info:
452+
self.session.key = cast_bytes(info['key'])
453+
if 'signature_scheme' in info:
454+
self.session.signature_scheme = info['signature_scheme']
431455

432456
#--------------------------------------------------------------------------
433457
# Creating connected sockets

jupyter_client/tests/test_connect.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from jupyter_core.application import JupyterApp
1313
from ipython_genutils.tempdir import TemporaryDirectory, TemporaryWorkingDirectory
1414
from ipython_genutils.py3compat import str_to_bytes
15-
from jupyter_client import connect
15+
from jupyter_client import connect, KernelClient
1616
from jupyter_client.consoleapp import JupyterConsoleApp
1717
from jupyter_client.session import Session
1818

@@ -92,6 +92,25 @@ def test_app_load_connection_file():
9292
nt.assert_equal(value, expected, "app.%s = %s != %s" % (attr, value, expected))
9393

9494

95+
def test_load_connection_info():
96+
client = KernelClient()
97+
info = {
98+
'control_port': 53702,
99+
'hb_port': 53705,
100+
'iopub_port': 53703,
101+
'ip': '0.0.0.0',
102+
'key': 'secret',
103+
'shell_port': 53700,
104+
'signature_scheme': 'hmac-sha256',
105+
'stdin_port': 53701,
106+
'transport': 'tcp',
107+
}
108+
client.load_connection_info(info)
109+
assert client.control_port == info['control_port']
110+
assert client.session.key.decode('ascii') == info['key']
111+
assert client.ip == info['ip']
112+
113+
95114
def test_find_connection_file():
96115
cfg = Config()
97116
with TemporaryDirectory() as d:

0 commit comments

Comments
 (0)