|
6 | 6 | import errno |
7 | 7 | import importlib |
8 | 8 | import logging |
| 9 | +import hashlib |
| 10 | +import hmac |
9 | 11 | import os |
10 | 12 | import sys |
11 | 13 | import signal |
12 | 14 | import select |
13 | 15 | import socket |
14 | 16 | import ssl |
15 | 17 | import threading |
| 18 | +from base64 import encodebytes |
16 | 19 | from distutils.util import strtobool |
17 | 20 |
|
18 | 21 | import nbformat |
19 | 22 | from jupyter_server.services.kernels.kernelmanager import MappingKernelManager |
20 | 23 |
|
21 | 24 | from urllib.parse import urlparse |
22 | | -from traitlets import Unicode, Integer, default, observe, Type, Instance, List, CBool |
| 25 | +from traitlets import Unicode, Integer, Bytes, default, observe, Type, Instance, List, CBool |
23 | 26 |
|
24 | 27 | from jupyter_core.application import JupyterApp, base_aliases |
25 | | -from jupyter_core.utils import run_sync |
26 | 28 | from jupyter_client.kernelspec import KernelSpecManager |
27 | 29 |
|
28 | 30 | from tornado import httpserver |
29 | 31 | from tornado import web, ioloop |
30 | 32 | from tornado.log import enable_pretty_logging, LogFormatter |
31 | 33 |
|
| 34 | +from jupyter_core.paths import secure_write |
32 | 35 | from jupyter_server.serverapp import random_ports |
33 | 36 | from ._version import __version__ |
34 | 37 | from .services.sessions.sessionmanager import SessionManager |
35 | 38 | from .services.kernels.manager import SeedingMappingKernelManager |
36 | 39 |
|
| 40 | +from .auth.identity import GatewayIdentityProvider |
| 41 | + |
37 | 42 | # Only present for generating help documentation |
38 | 43 | from .notebook_http import NotebookHTTPPersonality |
39 | 44 | from .jupyter_websocket import JupyterWebsocketPersonality |
40 | 45 |
|
41 | | -try: |
42 | | - from jupyter_server.auth.authorizer import AllowAllAuthorizer |
43 | | -except ImportError: |
44 | | - AllowAllAuthorizer = object |
| 46 | +from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer |
| 47 | +from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection |
| 48 | +from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection |
45 | 49 |
|
46 | 50 |
|
47 | 51 | # Add additional command line aliases |
@@ -311,6 +315,51 @@ def ssl_version_default(self): |
311 | 315 | ssl_from_env = os.getenv(self.ssl_version_env) |
312 | 316 | return ssl_from_env if ssl_from_env is None else int(ssl_from_env) |
313 | 317 |
|
| 318 | + cookie_secret_file = Unicode( |
| 319 | + config=True, help="""The file where the cookie secret is stored.""" |
| 320 | + ) |
| 321 | + |
| 322 | + @default("cookie_secret_file") |
| 323 | + def _default_cookie_secret_file(self): |
| 324 | + return os.path.join(self.runtime_dir, "jupyter_cookie_secret") |
| 325 | + |
| 326 | + cookie_secret = Bytes( |
| 327 | + b"", |
| 328 | + config=True, |
| 329 | + help="""The random bytes used to secure cookies. |
| 330 | + By default this is a new random number every time you start the server. |
| 331 | + Set it to a value in a config file to enable logins to persist across server sessions. |
| 332 | +
|
| 333 | + Note: Cookie secrets should be kept private, do not share config files with |
| 334 | + cookie_secret stored in plaintext (you can read the value from a file). |
| 335 | + """, |
| 336 | + ) |
| 337 | + |
| 338 | + @default("cookie_secret") |
| 339 | + def _default_cookie_secret(self): |
| 340 | + if os.path.exists(self.cookie_secret_file): |
| 341 | + with open(self.cookie_secret_file, "rb") as f: |
| 342 | + key = f.read() |
| 343 | + else: |
| 344 | + key = encodebytes(os.urandom(32)) |
| 345 | + self._write_cookie_secret_file(key) |
| 346 | + h = hmac.new(key, digestmod=hashlib.sha256) |
| 347 | + # h.update(self.password.encode()) # password is deprecated in 2.0 |
| 348 | + return h.digest() |
| 349 | + |
| 350 | + def _write_cookie_secret_file(self, secret): |
| 351 | + """write my secret to my secret_file""" |
| 352 | + self.log.info("Writing Jupyter server cookie secret to %s", self.cookie_secret_file) |
| 353 | + try: |
| 354 | + with secure_write(self.cookie_secret_file, True) as f: |
| 355 | + f.write(secret) |
| 356 | + except OSError as e: |
| 357 | + self.log.error( |
| 358 | + "Failed to write cookie secret to %s: %s", |
| 359 | + self.cookie_secret_file, |
| 360 | + e, |
| 361 | + ) |
| 362 | + |
314 | 363 | ws_ping_interval_env = "KG_WS_PING_INTERVAL_SECS" |
315 | 364 | ws_ping_interval_default_value = 30 |
316 | 365 | ws_ping_interval = Integer( |
@@ -352,6 +401,27 @@ def _default_log_format(self) -> str: |
352 | 401 | help="""The kernel manager class to use.""" |
353 | 402 | ) |
354 | 403 |
|
| 404 | + kernel_websocket_connection_class = Type( |
| 405 | + default_value=ZMQChannelsWebsocketConnection, |
| 406 | + klass=BaseKernelWebsocketConnection, |
| 407 | + config=True, |
| 408 | + help="""The kernel websocket connection class to use.""", |
| 409 | + ) |
| 410 | + |
| 411 | + authorizer_class = Type( |
| 412 | + default_value=AllowAllAuthorizer, |
| 413 | + klass=Authorizer, |
| 414 | + config=True, |
| 415 | + help="The authorizer class to use.", |
| 416 | + ) |
| 417 | + |
| 418 | + identity_provider_class = Type( |
| 419 | + default_value=GatewayIdentityProvider, |
| 420 | + klass=GatewayIdentityProvider, |
| 421 | + config=True, |
| 422 | + help="The identity provider class to use.", |
| 423 | + ) |
| 424 | + |
355 | 425 | def _load_api_module(self, module_name): |
356 | 426 | """Tries to import the given module name. |
357 | 427 |
|
@@ -467,6 +537,12 @@ def init_configurables(self): |
467 | 537 | ) |
468 | 538 | self.contents_manager = None |
469 | 539 |
|
| 540 | + self.identity_provider = self.identity_provider_class(parent=self, log=self.log) |
| 541 | + |
| 542 | + self.authorizer = self.authorizer_class( |
| 543 | + parent=self, log=self.log, identity_provider=self.identity_provider |
| 544 | + ) |
| 545 | + |
470 | 546 | if self.prespawn_count: |
471 | 547 | if self.max_kernels and self.prespawn_count > self.max_kernels: |
472 | 548 | msg = f"Cannot prespawn {self.prespawn_count} kernels; more than max kernels {self.max_kernels}" |
@@ -514,13 +590,17 @@ def init_webapp(self): |
514 | 590 | allow_origin=self.allow_origin, |
515 | 591 | # Set base_url for use in request handlers |
516 | 592 | base_url=self.base_url, |
| 593 | + # Authentication |
| 594 | + cookie_secret=self.cookie_secret, |
517 | 595 | # Always allow remote access (has been limited to localhost >= notebook 5.6) |
518 | 596 | allow_remote_access=True, |
519 | 597 | # setting ws_ping_interval value that can allow it to be modified for the purpose of toggling ping mechanism |
520 | 598 | # for zmq web-sockets or increasing/decreasing web socket ping interval/timeouts. |
521 | 599 | ws_ping_interval=self.ws_ping_interval * 1000, |
522 | 600 | # Add a pass-through authorizer for now |
523 | | - authorizer=AllowAllAuthorizer(), |
| 601 | + authorizer=self.authorizer_class(parent=self), |
| 602 | + identity_provider=self.identity_provider, |
| 603 | + kernel_websocket_connection_class=self.kernel_websocket_connection_class, |
524 | 604 | ) |
525 | 605 |
|
526 | 606 | # promote the current personality's "config" tagged traitlet values to webapp settings |
@@ -684,6 +764,7 @@ def start(self): |
684 | 764 | async def _stop(self): |
685 | 765 | """Cleanup resources and stop the IO Loop.""" |
686 | 766 | await self.personality.shutdown() |
| 767 | + await self.kernel_websocket_connection_class.close_all() |
687 | 768 | if getattr(self, "io_loop", None): |
688 | 769 | self.io_loop.stop() |
689 | 770 |
|
|
0 commit comments