Skip to content

Commit 195ed51

Browse files
stop hook for extensions (#526)
* stop hook for extensions * closes #241 * call a stop_extension method on server shutdown if present * fix typo * make extension stop hooks async * extension stop hook tests * extension stop hooks feedback * run_sync * extension stop hooks extension_apps property
1 parent f7290dc commit 195ed51

File tree

7 files changed

+156
-29
lines changed

7 files changed

+156
-29
lines changed

docs/source/developers/extensions.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,19 @@ The basic structure of an ExtensionApp is shown below:
156156
...
157157
# Change the jinja templating environment
158158
159+
async def stop_extension(self):
160+
...
161+
# Perform any required shut down steps
162+
159163
160164
The ``ExtensionApp`` uses the following methods and properties to connect your extension to the Jupyter server. You do not need to define a ``_load_jupyter_server_extension`` function for these apps. Instead, overwrite the pieces below to add your custom settings, handlers and templates:
161165

162166
Methods
163167

164-
* ``initialize_setting()``: adds custom settings to the Tornado Web Application.
168+
* ``initialize_settings()``: adds custom settings to the Tornado Web Application.
165169
* ``initialize_handlers()``: appends handlers to the Tornado Web Application.
166170
* ``initialize_templates()``: initialize the templating engine (e.g. jinja2) for your frontend.
171+
* ``stop_extension()``: called on server shut down.
167172

168173
Properties
169174

jupyter_server/extension/application.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,9 @@ def start(self):
420420
# Start the server.
421421
self.serverapp.start()
422422

423+
async def stop_extension(self):
424+
"""Cleanup any resources managed by this extension."""
425+
423426
def stop(self):
424427
"""Stop the underlying Jupyter server.
425428
"""

jupyter_server/extension/manager.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
import traceback
44

5+
from tornado.gen import multi
6+
57
from traitlets.config import LoggingConfigurable
68

79
from traitlets import (
@@ -230,15 +232,17 @@ def link_point(self, point_name, serverapp):
230232

231233
def load_point(self, point_name, serverapp):
232234
point = self.extension_points[point_name]
233-
point.load(serverapp)
235+
return point.load(serverapp)
234236

235237
def link_all_points(self, serverapp):
236238
for point_name in self.extension_points:
237239
self.link_point(point_name, serverapp)
238240

239241
def load_all_points(self, serverapp):
240-
for point_name in self.extension_points:
242+
return [
241243
self.load_point(point_name, serverapp)
244+
for point_name in self.extension_points
245+
]
242246

243247

244248
class ExtensionManager(LoggingConfigurable):
@@ -290,12 +294,26 @@ def sorted_extensions(self):
290294
"""
291295
)
292296

297+
@property
298+
def extension_apps(self):
299+
"""Return mapping of extension names and sets of ExtensionApp objects.
300+
"""
301+
return {
302+
name: {
303+
point.app
304+
for point in extension.extension_points.values()
305+
if point.app
306+
}
307+
for name, extension in self.extensions.items()
308+
}
309+
293310
@property
294311
def extension_points(self):
295-
extensions = self.extensions
312+
"""Return mapping of extension point names and ExtensionPoint objects.
313+
"""
296314
return {
297315
name: point
298-
for value in extensions.values()
316+
for value in self.extensions.values()
299317
for name, point in value.extension_points.items()
300318
}
301319

@@ -341,13 +359,22 @@ def link_extension(self, name, serverapp):
341359

342360
def load_extension(self, name, serverapp):
343361
extension = self.extensions.get(name)
362+
344363
if extension.enabled:
345364
try:
346-
extension.load_all_points(serverapp)
347-
self.log.info("{name} | extension was successfully loaded.".format(name=name))
365+
points = extension.load_all_points(serverapp)
348366
except Exception as e:
349367
self.log.debug("".join(traceback.format_exception(*sys.exc_info())))
350368
self.log.warning("{name} | extension failed loading with message: {error}".format(name=name,error=str(e)))
369+
else:
370+
self.log.info("{name} | extension was successfully loaded.".format(name=name))
371+
372+
async def stop_extension(self, name, apps):
373+
"""Call the shutdown hooks in the specified apps."""
374+
for app in apps:
375+
self.log.debug('{} | extension app "{}" stopping'.format(name, app.name))
376+
await app.stop_extension()
377+
self.log.debug('{} | extension app "{}" stopped'.format(name, app.name))
351378

352379
def link_all_extensions(self, serverapp):
353380
"""Link all enabled extensions
@@ -366,3 +393,10 @@ def load_all_extensions(self, serverapp):
366393
# order.
367394
for name in self.sorted_extensions.keys():
368395
self.load_extension(name, serverapp)
396+
397+
async def stop_all_extensions(self, serverapp):
398+
"""Call the shutdown hooks in all extensions."""
399+
await multi([
400+
self.stop_extension(name, apps)
401+
for name, apps in sorted(dict(self.extension_apps).items())
402+
])

jupyter_server/pytest_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from jupyter_server.extension import serverextension
2020
from jupyter_server.serverapp import ServerApp
21-
from jupyter_server.utils import url_path_join
21+
from jupyter_server.utils import url_path_join, run_sync
2222
from jupyter_server.services.contents.filemanager import FileContentsManager
2323
from jupyter_server.services.contents.largefilemanager import LargeFileManager
2424

@@ -284,7 +284,7 @@ def jp_serverapp(
284284
"""Starts a Jupyter Server instance based on the established configuration values."""
285285
app = jp_configurable_serverapp(config=jp_server_config, argv=jp_argv)
286286
yield app
287-
app._cleanup()
287+
run_sync(app._cleanup())
288288

289289

290290
@pytest.fixture

jupyter_server/serverapp.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
from jupyter_core.paths import secure_write
4545
from jupyter_server.transutils import trans, _i18n
46-
from jupyter_server.utils import run_sync
46+
from jupyter_server.utils import run_sync_in_loop
4747

4848
# the minimum viable tornado version: needs to be kept in sync with setup.py
4949
MIN_TORNADO = (6, 1, 0)
@@ -1777,7 +1777,7 @@ def _confirm_exit(self):
17771777
self.log.critical(_i18n("Shutting down..."))
17781778
# schedule stop on the main thread,
17791779
# since this might be called from a signal handler
1780-
self.io_loop.add_callback_from_signal(self.io_loop.stop)
1780+
self.stop(from_signal=True)
17811781
return
17821782
print(self.running_server_info())
17831783
yes = _i18n('y')
@@ -1791,7 +1791,7 @@ def _confirm_exit(self):
17911791
self.log.critical(_i18n("Shutdown confirmed"))
17921792
# schedule stop on the main thread,
17931793
# since this might be called from a signal handler
1794-
self.io_loop.add_callback_from_signal(self.io_loop.stop)
1794+
self.stop(from_signal=True)
17951795
return
17961796
else:
17971797
print(_i18n("No answer for 5s:"), end=' ')
@@ -1804,7 +1804,7 @@ def _confirm_exit(self):
18041804

18051805
def _signal_stop(self, sig, frame):
18061806
self.log.critical(_i18n("received signal %s, stopping"), sig)
1807-
self.io_loop.add_callback_from_signal(self.io_loop.stop)
1807+
self.stop(from_signal=True)
18081808

18091809
def _signal_info(self, sig, frame):
18101810
print(self.running_server_info())
@@ -2086,7 +2086,7 @@ def initialize(self, argv=None, find_extensions=True, new_httpserver=True, start
20862086
if new_httpserver:
20872087
self.init_httpserver()
20882088

2089-
def cleanup_kernels(self):
2089+
async def cleanup_kernels(self):
20902090
"""Shutdown all kernels.
20912091
20922092
The kernels will shutdown themselves when this process no longer exists,
@@ -2095,9 +2095,9 @@ def cleanup_kernels(self):
20952095
n_kernels = len(self.kernel_manager.list_kernel_ids())
20962096
kernel_msg = trans.ngettext('Shutting down %d kernel', 'Shutting down %d kernels', n_kernels)
20972097
self.log.info(kernel_msg % n_kernels)
2098-
run_sync(self.kernel_manager.shutdown_all())
2098+
await run_sync_in_loop(self.kernel_manager.shutdown_all())
20992099

2100-
def cleanup_terminals(self):
2100+
async def cleanup_terminals(self):
21012101
"""Shutdown all terminals.
21022102
21032103
The terminals will shutdown themselves when this process no longer exists,
@@ -2110,7 +2110,20 @@ def cleanup_terminals(self):
21102110
n_terminals = len(terminal_manager.list())
21112111
terminal_msg = trans.ngettext('Shutting down %d terminal', 'Shutting down %d terminals', n_terminals)
21122112
self.log.info(terminal_msg % n_terminals)
2113-
run_sync(terminal_manager.terminate_all())
2113+
await run_sync_in_loop(terminal_manager.terminate_all())
2114+
2115+
async def cleanup_extensions(self):
2116+
"""Call shutdown hooks in all extensions."""
2117+
n_extensions = len(self.extension_manager.extension_apps)
2118+
extension_msg = trans.ngettext(
2119+
'Shutting down %d extension',
2120+
'Shutting down %d extensions',
2121+
n_extensions
2122+
)
2123+
self.log.info(extension_msg % n_extensions)
2124+
await run_sync_in_loop(
2125+
self.extension_manager.stop_all_extensions(self)
2126+
)
21142127

21152128
def running_server_info(self, kernel_count=True):
21162129
"Return the current working directory and the server url information"
@@ -2348,14 +2361,15 @@ def start_app(self):
23482361
' %s' % self.display_url,
23492362
]))
23502363

2351-
def _cleanup(self):
2352-
"""General cleanup of files and kernels created
2364+
async def _cleanup(self):
2365+
"""General cleanup of files, extensions and kernels created
23532366
by this instance ServerApp.
23542367
"""
23552368
self.remove_server_info_file()
23562369
self.remove_browser_open_files()
2357-
self.cleanup_kernels()
2358-
self.cleanup_terminals()
2370+
await self.cleanup_extensions()
2371+
await self.cleanup_kernels()
2372+
await self.cleanup_terminals()
23592373

23602374
def start_ioloop(self):
23612375
"""Start the IO Loop."""
@@ -2368,8 +2382,6 @@ def start_ioloop(self):
23682382
self.io_loop.start()
23692383
except KeyboardInterrupt:
23702384
self.log.info(_i18n("Interrupted..."))
2371-
finally:
2372-
self._cleanup()
23732385

23742386
def init_ioloop(self):
23752387
"""init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks"""
@@ -2383,13 +2395,23 @@ def start(self):
23832395
self.start_app()
23842396
self.start_ioloop()
23852397

2386-
def stop(self):
2387-
def _stop():
2398+
async def _stop(self):
2399+
"""Cleanup resources and stop the IO Loop."""
2400+
await self._cleanup()
2401+
self.io_loop.stop()
2402+
2403+
def stop(self, from_signal=False):
2404+
"""Cleanup resources and stop the server."""
2405+
if hasattr(self, '_http_server'):
23882406
# Stop a server if its set.
2389-
if hasattr(self, '_http_server'):
2390-
self.http_server.stop()
2391-
self.io_loop.stop()
2392-
self.io_loop.add_callback(_stop)
2407+
self.http_server.stop()
2408+
if getattr(self, 'io_loop', None):
2409+
# use IOLoop.add_callback because signal.signal must be called
2410+
# from main thread
2411+
if from_signal:
2412+
self.io_loop.add_callback_from_signal(self._stop)
2413+
else:
2414+
self.io_loop.add_callback(self._stop)
23932415

23942416

23952417
def list_running_servers(runtime_dir=None):

jupyter_server/tests/extension/test_app.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from traitlets.config import Config
33
from jupyter_server.serverapp import ServerApp
4+
from jupyter_server.utils import run_sync
45
from .mockextensions.app import MockExtensionApp
56

67

@@ -101,3 +102,42 @@ def test_load_parallel_extensions(monkeypatch, jp_environ):
101102
exts = serverapp.jpserver_extensions
102103
assert exts['jupyter_server.tests.extension.mockextensions.mock1']
103104
assert exts['jupyter_server.tests.extension.mockextensions']
105+
106+
107+
def test_stop_extension(jp_serverapp, caplog):
108+
"""Test the stop_extension method.
109+
110+
This should be fired by ServerApp.cleanup_extensions.
111+
"""
112+
calls = 0
113+
114+
# load extensions (make sure we only have the one extension loaded
115+
jp_serverapp.extension_manager.load_all_extensions(jp_serverapp)
116+
extension_name = 'jupyter_server.tests.extension.mockextensions'
117+
assert list(jp_serverapp.extension_manager.extension_apps) == [
118+
extension_name
119+
]
120+
121+
# add a stop_extension method for the extension app
122+
async def _stop(*args):
123+
nonlocal calls
124+
calls += 1
125+
for apps in jp_serverapp.extension_manager.extension_apps.values():
126+
for app in apps:
127+
if app:
128+
app.stop_extension = _stop
129+
130+
# call cleanup_extensions, check the logging is correct
131+
caplog.clear()
132+
run_sync(jp_serverapp.cleanup_extensions())
133+
assert [
134+
msg
135+
for *_, msg in caplog.record_tuples
136+
] == [
137+
'Shutting down 1 extension',
138+
'{} | extension app "mockextension" stopping'.format(extension_name),
139+
'{} | extension app "mockextension" stopped'.format(extension_name),
140+
]
141+
142+
# check the shutdown method was called once
143+
assert calls == 1

jupyter_server/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,29 @@ def wrapped():
232232
return wrapped()
233233

234234

235+
async def run_sync_in_loop(maybe_async):
236+
"""Runs a function synchronously whether it is an async function or not.
237+
238+
If async, runs maybe_async and blocks until it has executed.
239+
240+
If not async, just returns maybe_async as it is the result of something
241+
that has already executed.
242+
243+
Parameters
244+
----------
245+
maybe_async : async or non-async object
246+
The object to be executed, if it is async.
247+
248+
Returns
249+
-------
250+
result
251+
Whatever the async object returns, or the object itself.
252+
"""
253+
if not inspect.isawaitable(maybe_async):
254+
return maybe_async
255+
return await maybe_async
256+
257+
235258
def urlencode_unix_socket_path(socket_path):
236259
"""Encodes a UNIX socket path string from a socket path for the `http+unix` URI form."""
237260
return socket_path.replace('/', '%2F')

0 commit comments

Comments
 (0)