Skip to content

Commit 2e36f8a

Browse files
committed
Fix nudge and activity issues when kernel ports change on restarts
1 parent 41e837f commit 2e36f8a

File tree

3 files changed

+166
-13
lines changed

3 files changed

+166
-13
lines changed

jupyter_server/services/kernels/handlers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,13 @@ def open(self, kernel_id):
369369
buffer_info = km.get_buffer(kernel_id, self.session_key)
370370
if buffer_info and buffer_info['session_key'] == self.session_key:
371371
self.log.info("Restoring connection for %s", self.session_key)
372-
self.channels = buffer_info['channels']
372+
if km.ports_changed(kernel_id):
373+
# If the kernel's ports have changed (some restarts trigger this)
374+
# then reset the channels so nudge() is using the correct iopub channel
375+
self.create_stream()
376+
else:
377+
# The kernel's ports have not changed; use the channels captured in the buffer
378+
self.channels = buffer_info['channels']
373379

374380
connected = self.nudge()
375381

@@ -381,15 +387,14 @@ def replay(value):
381387
stream = self.channels[channel]
382388
self._on_zmq_reply(stream, msg_list)
383389

384-
385390
connected.add_done_callback(replay)
386391
else:
387392
try:
388393
self.create_stream()
389394
connected = self.nudge()
390395
except web.HTTPError as e:
391396
self.log.error("Error opening stream: %s", e)
392-
# WebSockets don't response to traditional error codes so we
397+
# WebSockets don't respond to traditional error codes so we
393398
# close the connection.
394399
for channel, stream in self.channels.items():
395400
if not stream.closed():

jupyter_server/services/kernels/kernelmanager.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _default_kernel_manager_class(self):
4747

4848
_kernel_connections = Dict()
4949

50+
_kernel_ports = Dict()
51+
5052
_culler_callback = None
5153

5254
_initialized_culler = False
@@ -183,6 +185,7 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
183185
kwargs['cwd'] = self.cwd_for_path(path)
184186
kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
185187
self._kernel_connections[kernel_id] = 0
188+
self._kernel_ports[kernel_id] = self._kernels[kernel_id].ports
186189
self.start_watching_activity(kernel_id)
187190
self.log.info("Kernel started: %s" % kernel_id)
188191
self.log.debug("Kernel args: %r" % kwargs)
@@ -208,6 +211,40 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs):
208211

209212
return kernel_id
210213

214+
def ports_changed(self, kernel_id):
    """Report whether a kernel's ports changed and, if so, recapture them.

    Used by ZMQChannelsHandler to determine how to coordinate nudge and replays.

    Ports are captured when starting a kernel (via MappingKernelManager). Ports
    are considered changed (following restarts) if the referenced KernelManager
    is using a set of ports different from those captured at startup. If changes
    are detected, the captured set is updated and a value of True is returned.

    NOTE: Use is exclusive to ZMQChannelsHandler because this object is a singleton
    instance while ZMQChannelsHandler instances are per WebSocket connection that
    can vary per kernel lifetime.

    :param kernel_id: identifier of the kernel whose ports are checked.
    :return: True if the ports differ from the captured set (which is then
        replaced by the current ports), otherwise False.
    """
    changed_ports = self._get_changed_ports(kernel_id)
    # `_get_changed_ports` signals "no change" with None, so test for that
    # sentinel explicitly rather than relying on the truthiness of the port list.
    if changed_ports is None:
        return False

    # Lazy %-style arguments: the message is only formatted if debug is enabled.
    self.log.debug("Port change detected for kernel: %s", kernel_id)
    self._kernel_ports[kernel_id] = changed_ports
    return True
233+
234+
def _get_changed_ports(self, kernel_id):
    """Internal method to test if a kernel's ports have changed and, if so, return their values.

    This method does NOT update the captured ports for the kernel as that can only be done
    by ZMQChannelsHandler, but instead returns the new list of ports if they are different
    than those captured at startup. This enables the ability to conditionally restart
    activity monitoring immediately following a kernel's restart (if ports have changed).

    :param kernel_id: identifier of the kernel whose ports are compared.
    :return: the kernel's current ports if they differ from the captured ports
        (or if no ports were captured for this kernel), otherwise None.
    """
    # Read the ports once so the value compared and the value returned are
    # guaranteed to be the same snapshot.
    current_ports = self.get_kernel(kernel_id).ports
    # A kernel with no captured ports is reported as changed instead of raising
    # KeyError; the caller can then capture the current ports.
    if current_ports != self._kernel_ports.get(kernel_id):
        return current_ports
    return None
247+
211248
def start_buffering(self, kernel_id, session_key, channels):
212249
"""Start buffering messages for a kernel
213250
@@ -300,10 +337,7 @@ def stop_buffering(self, kernel_id):
300337
def shutdown_kernel(self, kernel_id, now=False, restart=False):
301338
"""Shutdown a kernel by kernel_id"""
302339
self._check_kernel_id(kernel_id)
303-
kernel = self._kernels[kernel_id]
304-
if kernel._activity_stream:
305-
kernel._activity_stream.close()
306-
kernel._activity_stream = None
340+
self.stop_watching_activity(kernel_id)
307341
self.stop_buffering(kernel_id)
308342
self._kernel_connections.pop(kernel_id, None)
309343

@@ -319,6 +353,7 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False):
319353
# method is synchronous. However, we'll keep the relative call orders the same from
320354
# a maintenance perspective.
321355
self._kernel_connections.pop(kernel_id, None)
356+
self._kernel_ports.pop(kernel_id, None)
322357

323358
async def restart_kernel(self, kernel_id, now=False):
324359
"""Restart a kernel by kernel_id"""
@@ -359,6 +394,10 @@ def on_restart_failed():
359394
channel.on_recv(on_reply)
360395
loop = IOLoop.current()
361396
timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout)
397+
# Re-establish activity watching if ports have changed...
398+
if self._get_changed_ports(kernel_id) is not None:
399+
self.stop_watching_activity(kernel_id)
400+
self.start_watching_activity(kernel_id)
362401
return future
363402

364403
def notify_connect(self, kernel_id):
@@ -440,6 +479,13 @@ def record_activity(msg_list):
440479

441480
kernel._activity_stream.on_recv(record_activity)
442481

482+
def stop_watching_activity(self, kernel_id):
    """Stop watching IOPub messages on a kernel for activity.

    Closes the kernel's activity stream, if it has one, and clears the
    reference. Calling this for a kernel that has no activity stream (already
    stopped, or never started) is a no-op.

    :param kernel_id: identifier of the kernel; must be a key of ``self._kernels``.
    """
    kernel = self._kernels[kernel_id]
    # NOTE(review): `_activity_stream` appears to be attached to the kernel when
    # activity watching starts; tolerate its absence rather than raising
    # AttributeError for a kernel that was never watched - confirm against
    # start_watching_activity.
    stream = getattr(kernel, '_activity_stream', None)
    if stream:
        stream.close()
        kernel._activity_stream = None
488+
443489
def initialize_culler(self):
444490
"""Start idle culler if 'cull_idle_timeout' is greater than zero.
445491
@@ -511,10 +557,7 @@ def __init__(self, **kwargs):
511557
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
512558
"""Shutdown a kernel by kernel_id"""
513559
self._check_kernel_id(kernel_id)
514-
kernel = self._kernels[kernel_id]
515-
if kernel._activity_stream:
516-
kernel._activity_stream.close()
517-
kernel._activity_stream = None
560+
self.stop_watching_activity(kernel_id)
518561
self.stop_buffering(kernel_id)
519562

520563
# Decrease the metric of number of kernels
@@ -526,4 +569,5 @@ async def shutdown_kernel(self, kernel_id, now=False, restart=False):
526569
# Finish shutting down the kernel before clearing state to avoid a race condition.
527570
ret = await self.pinned_superclass.shutdown_kernel(self, kernel_id, now=now, restart=restart)
528571
self._kernel_connections.pop(kernel_id, None)
572+
self._kernel_ports.pop(kernel_id, None)
529573
return ret

jupyter_server/tests/services/sessions/test_api.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,47 @@
1-
import sys
21
import time
32
import json
43
import shutil
54
import pytest
65

76
import tornado
87

8+
from jupyter_client.ioloop import AsyncIOLoopKernelManager
9+
910
from nbformat.v4 import new_notebook
1011
from nbformat import writes
12+
from traitlets import default
1113

1214
from ...utils import expected_http_error
15+
from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
1316
from jupyter_server.utils import url_path_join
1417

1518

1619
j = lambda r: json.loads(r.body.decode())
1720

1821

19-
@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager"])
22+
class NewPortsKernelManager(AsyncIOLoopKernelManager):
    """Test kernel manager that disables port caching and restarts with ``newports=True`` by default."""

    @default('cache_ports')
    def _default_cache_ports(self) -> bool:
        """Default the ``cache_ports`` trait to False for this test manager."""
        return False

    async def restart_kernel(self, now: bool = False, newports: bool = True, **kw) -> None:
        """Delegate to the parent restart, with ``newports`` defaulting to True here."""
        self.log.debug(f"DEBUG**** calling super().restart_kernel with newports={newports}")
        outcome = await super().restart_kernel(now=now, newports=newports, **kw)
        return outcome
31+
32+
33+
class NewPortsMappingKernelManager(AsyncMappingKernelManager):
    """Test mapping manager whose default kernel manager class is NewPortsKernelManager."""

    @default('kernel_manager_class')
    def _default_kernel_manager_class(self):
        """Return the dotted import path of the kernel manager class used by this test manager."""
        self.log.debug("NewPortsMappingKernelManager in _default_kernel_manager_class!")
        kernel_manager_path = "jupyter_server.tests.services.sessions.test_api.NewPortsKernelManager"
        return kernel_manager_path
39+
40+
41+
@pytest.fixture(params=["MappingKernelManager", "AsyncMappingKernelManager", "NewPortsMappingKernelManager"])
def jp_argv(request):
    """Build the server argv selecting the kernel manager class named by the fixture parameter.

    The NewPortsMappingKernelManager class lives in this test module; the other
    two parameters name classes in the kernelmanager service module.
    """
    manager_name = request.param
    if manager_name == "NewPortsMappingKernelManager":
        option_prefix = "--ServerApp.kernel_manager_class=jupyter_server.tests.services.sessions.test_api."
    else:
        option_prefix = "--ServerApp.kernel_manager_class=jupyter_server.services.kernels.kernelmanager."
    return [option_prefix + manager_name]
2246

2347

@@ -339,3 +363,83 @@ async def test_modify_kernel_id(session_client, jp_fetch):
339363

340364
# Need to find a better solution to this.
341365
await session_client.cleanup()
366+
367+
368+
async def test_restart_kernel(session_client, jp_base_url, jp_fetch, jp_ws_fetch):
    """Check that websocket connection counts stay correct across a kernel restart.

    Creates a session, opens a websocket to its kernel, restarts the kernel,
    closes the websocket, and then verifies that a fresh websocket can be
    opened and is counted as the only connection.
    """
    # Imported locally so this test does not depend on a module-level import.
    import asyncio

    async def connection_count(kernel_id):
        """Fetch the kernel model and return its reported number of connections."""
        r = await jp_fetch(
            'api', 'kernels', kernel_id,
            method='GET'
        )
        return json.loads(r.body.decode())['connections']

    async def wait_for_disconnect(kernel_id):
        """Poll (up to 10 times) until the server reports no connections."""
        for _ in range(10):
            if await connection_count(kernel_id) == 0:
                break
            # Yield to the event loop instead of blocking it with time.sleep(),
            # so the close can be processed while we wait.
            await asyncio.sleep(0.1)

    # Create a session.
    resp = await session_client.create('foo/nb1.ipynb')
    assert resp.code == 201
    new_session = j(resp)
    assert 'id' in new_session
    assert new_session['path'] == 'foo/nb1.ipynb'
    assert new_session['type'] == 'notebook'
    assert resp.headers['Location'] == url_path_join(jp_base_url, '/api/sessions/', new_session['id'])

    kid = new_session['kernel']['id']

    # No websocket has been opened yet.
    assert await connection_count(kid) == 0

    # Open a websocket connection.
    ws = await jp_ws_fetch(
        'api', 'kernels', kid, 'channels'
    )

    # Test that it was opened.
    assert await connection_count(kid) == 1

    # Restart kernel
    r = await jp_fetch(
        'api', 'kernels', kid, 'restart',
        method='POST',
        allow_nonstandard_methods=True
    )
    restarted_kernel = json.loads(r.body.decode())
    assert restarted_kernel['id'] == kid

    # Close/open websocket
    ws.close()
    # give it some time to close on the other side:
    await wait_for_disconnect(kid)
    assert await connection_count(kid) == 0

    # Open a websocket connection.
    ws2 = await jp_ws_fetch(
        'api', 'kernels', kid, 'channels'
    )
    assert await connection_count(kid) == 1

    # Close the second websocket too so it is not leaked past the test, and
    # give the server a chance to notice before cleaning up.
    ws2.close()
    await wait_for_disconnect(kid)

    # Need to find a better solution to this.
    await session_client.cleanup()

0 commit comments

Comments
 (0)