Skip to content

Commit 050f988

Browse files
kevin-batesZsailer
authored andcommitted
Add GatewayIdentityProvider, get tests working
1 parent 3a7e9b4 commit 050f988

File tree

5 files changed

+176
-55
lines changed

5 files changed

+176
-55
lines changed

conftest.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
# Copyright (c) Jupyter Development Team.
22
# Distributed under the terms of the Modified BSD License.
33

4+
import os
45
import logging
56
import pytest
6-
7-
from jupyter_server import version_info as jupyter_server_version
8-
7+
from binascii import hexlify
8+
from traitlets.config import Config
99
from kernel_gateway.gatewayapp import KernelGatewayApp
1010

11-
is_v2 = jupyter_server_version[0] == 2
12-
1311
pytest_plugins = ["pytest_jupyter.jupyter_core", "pytest_jupyter.jupyter_server"]
1412

1513

@@ -45,20 +43,22 @@ def _configurable_serverapp(
4543
config=jp_server_config,
4644
base_url=jp_base_url,
4745
argv=jp_argv,
48-
environ=jp_environ,
4946
http_port=jp_http_port,
50-
tmp_path=tmp_path,
51-
io_loop=io_loop,
52-
root_dir=jp_root_dir,
5347
**kwargs,
5448
):
49+
c = Config(config)
50+
51+
if "auth_token" not in c.KernelGatewayApp and not c.IdentityProvider.token:
52+
default_token = hexlify(os.urandom(4)).decode("ascii")
53+
c.IdentityProvider.token = default_token
54+
5555
app = KernelGatewayApp.instance(
5656
# Set the log level to debug for testing purposes
5757
log_level="DEBUG",
58-
port=jp_http_port,
58+
port=http_port,
5959
port_retries=0,
6060
base_url=base_url,
61-
config=config,
61+
config=c,
6262
**kwargs,
6363
)
6464
app.log.propagate = True
@@ -101,6 +101,4 @@ def jp_server_cleanup(jp_asyncio_loop):
101101
@pytest.fixture
102102
def jp_auth_header(jp_serverapp):
103103
"""Configures an authorization header using the token from the serverapp fixture."""
104-
if not is_v2:
105-
return {"Authorization": f"token {jp_serverapp.auth_token}"}
106104
return {"Authorization": f"token {jp_serverapp.identity_provider.token}"}

kernel_gateway/auth/identity.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Jupyter Development Team.
2+
# Distributed under the terms of the Modified BSD License.
3+
"""Gateway Identity Provider interface
4+
5+
This defines the _authentication_ layer of Jupyter Server,
6+
to be used in combination with Authorizer for _authorization_.
7+
"""
8+
from traitlets import default
9+
10+
from jupyter_server.auth.identity import IdentityProvider
11+
from jupyter_server.base.handlers import JupyterHandler
12+
13+
14+
class GatewayIdentityProvider(IdentityProvider):
15+
"""
16+
Interface for providing identity management and authentication for a Gateway server.
17+
"""
18+
@default("token")
19+
def _token_default(self):
20+
# if the superclass generated a token, but auth_token is configured on
21+
# the Gateway server, reset token_generated and use the configured value.
22+
token_default = super()._token_default()
23+
if self.token_generated and self.parent.auth_token:
24+
self.token_generated = False
25+
return self.parent.auth_token
26+
return token_default
27+
28+
def should_check_origin(self, handler: JupyterHandler) -> bool:
29+
"""Should the Handler check for CORS origin validation?
30+
31+
Origin check should be skipped for token-authenticated requests.
32+
33+
Returns:
34+
- True, if Handler must check for valid CORS origin.
35+
- False, if Handler should skip origin check since requests are token-authenticated.
36+
"""
37+
# Always check the origin unless operator configured gateway to allow any
38+
return handler.settings["kg_allow_origin"] != "*"

kernel_gateway/gatewayapp.py

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,46 @@
66
import errno
77
import importlib
88
import logging
9+
import hashlib
10+
import hmac
911
import os
1012
import sys
1113
import signal
1214
import select
1315
import socket
1416
import ssl
1517
import threading
18+
from base64 import encodebytes
1619
from distutils.util import strtobool
1720

1821
import nbformat
1922
from jupyter_server.services.kernels.kernelmanager import MappingKernelManager
2023

2124
from urllib.parse import urlparse
22-
from traitlets import Unicode, Integer, default, observe, Type, Instance, List, CBool
25+
from traitlets import Unicode, Integer, Bytes, default, observe, Type, Instance, List, CBool
2326

2427
from jupyter_core.application import JupyterApp, base_aliases
25-
from jupyter_core.utils import run_sync
2628
from jupyter_client.kernelspec import KernelSpecManager
2729

2830
from tornado import httpserver
2931
from tornado import web, ioloop
3032
from tornado.log import enable_pretty_logging, LogFormatter
3133

34+
from jupyter_core.paths import secure_write
3235
from jupyter_server.serverapp import random_ports
3336
from ._version import __version__
3437
from .services.sessions.sessionmanager import SessionManager
3538
from .services.kernels.manager import SeedingMappingKernelManager
3639

40+
from .auth.identity import GatewayIdentityProvider
41+
3742
# Only present for generating help documentation
3843
from .notebook_http import NotebookHTTPPersonality
3944
from .jupyter_websocket import JupyterWebsocketPersonality
4045

41-
try:
42-
from jupyter_server.auth.authorizer import AllowAllAuthorizer
43-
except ImportError:
44-
AllowAllAuthorizer = object
46+
from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer
47+
from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection
48+
from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection
4549

4650

4751
# Add additional command line aliases
@@ -311,6 +315,51 @@ def ssl_version_default(self):
311315
ssl_from_env = os.getenv(self.ssl_version_env)
312316
return ssl_from_env if ssl_from_env is None else int(ssl_from_env)
313317

318+
cookie_secret_file = Unicode(
319+
config=True, help="""The file where the cookie secret is stored."""
320+
)
321+
322+
@default("cookie_secret_file")
323+
def _default_cookie_secret_file(self):
324+
return os.path.join(self.runtime_dir, "jupyter_cookie_secret")
325+
326+
cookie_secret = Bytes(
327+
b"",
328+
config=True,
329+
help="""The random bytes used to secure cookies.
330+
By default this is a new random number every time you start the server.
331+
Set it to a value in a config file to enable logins to persist across server sessions.
332+
333+
Note: Cookie secrets should be kept private, do not share config files with
334+
cookie_secret stored in plaintext (you can read the value from a file).
335+
""",
336+
)
337+
338+
@default("cookie_secret")
339+
def _default_cookie_secret(self):
340+
if os.path.exists(self.cookie_secret_file):
341+
with open(self.cookie_secret_file, "rb") as f:
342+
key = f.read()
343+
else:
344+
key = encodebytes(os.urandom(32))
345+
self._write_cookie_secret_file(key)
346+
h = hmac.new(key, digestmod=hashlib.sha256)
347+
# h.update(self.password.encode()) # password is deprecated in 2.0
348+
return h.digest()
349+
350+
def _write_cookie_secret_file(self, secret):
351+
"""write my secret to my secret_file"""
352+
self.log.info("Writing Jupyter server cookie secret to %s", self.cookie_secret_file)
353+
try:
354+
with secure_write(self.cookie_secret_file, True) as f:
355+
f.write(secret)
356+
except OSError as e:
357+
self.log.error(
358+
"Failed to write cookie secret to %s: %s",
359+
self.cookie_secret_file,
360+
e,
361+
)
362+
314363
ws_ping_interval_env = "KG_WS_PING_INTERVAL_SECS"
315364
ws_ping_interval_default_value = 30
316365
ws_ping_interval = Integer(
@@ -352,6 +401,27 @@ def _default_log_format(self) -> str:
352401
help="""The kernel manager class to use."""
353402
)
354403

404+
kernel_websocket_connection_class = Type(
405+
default_value=ZMQChannelsWebsocketConnection,
406+
klass=BaseKernelWebsocketConnection,
407+
config=True,
408+
help="""The kernel websocket connection class to use.""",
409+
)
410+
411+
authorizer_class = Type(
412+
default_value=AllowAllAuthorizer,
413+
klass=Authorizer,
414+
config=True,
415+
help="The authorizer class to use.",
416+
)
417+
418+
identity_provider_class = Type(
419+
default_value=GatewayIdentityProvider,
420+
klass=GatewayIdentityProvider,
421+
config=True,
422+
help="The identity provider class to use.",
423+
)
424+
355425
def _load_api_module(self, module_name):
356426
"""Tries to import the given module name.
357427
@@ -467,6 +537,12 @@ def init_configurables(self):
467537
)
468538
self.contents_manager = None
469539

540+
self.identity_provider = self.identity_provider_class(parent=self, log=self.log)
541+
542+
self.authorizer = self.authorizer_class(
543+
parent=self, log=self.log, identity_provider=self.identity_provider
544+
)
545+
470546
if self.prespawn_count:
471547
if self.max_kernels and self.prespawn_count > self.max_kernels:
472548
msg = f"Cannot prespawn {self.prespawn_count} kernels; more than max kernels {self.max_kernels}"
@@ -514,13 +590,17 @@ def init_webapp(self):
514590
allow_origin=self.allow_origin,
515591
# Set base_url for use in request handlers
516592
base_url=self.base_url,
593+
# Authentication
594+
cookie_secret=self.cookie_secret,
517595
# Always allow remote access (has been limited to localhost >= notebook 5.6)
518596
allow_remote_access=True,
519597
# setting ws_ping_interval value that can allow it to be modified for the purpose of toggling ping mechanism
520598
# for zmq web-sockets or increasing/decreasing web socket ping interval/timeouts.
521599
ws_ping_interval=self.ws_ping_interval * 1000,
522600
# Add a pass-through authorizer for now
523-
authorizer=AllowAllAuthorizer(),
601+
authorizer=self.authorizer_class(parent=self),
602+
identity_provider=self.identity_provider,
603+
kernel_websocket_connection_class=self.kernel_websocket_connection_class,
524604
)
525605

526606
# promote the current personality's "config" tagged traitlet values to webapp settings
@@ -684,6 +764,7 @@ def start(self):
684764
async def _stop(self):
685765
"""Cleanup resources and stop the IO Loop."""
686766
await self.personality.shutdown()
767+
await self.kernel_websocket_connection_class.close_all()
687768
if getattr(self, "io_loop", None):
688769
self.io_loop.stop()
689770

kernel_gateway/mixins.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@ class CORSMixin(object):
1515
'kg_allow_headers': 'Access-Control-Allow-Headers',
1616
'kg_allow_methods': 'Access-Control-Allow-Methods',
1717
'kg_allow_origin': 'Access-Control-Allow-Origin',
18-
'kg_expose_headers' : 'Access-Control-Expose-Headers',
18+
'kg_expose_headers': 'Access-Control-Expose-Headers',
1919
'kg_max_age': 'Access-Control-Max-Age'
2020
}
2121

22-
def set_default_headers(self):
22+
def set_cors_headers(self):
2323
"""Sets the CORS headers as the default for all responses.
2424
2525
Disables CSP configured by the notebook package. It's not necessary
2626
for a programmatic API.
27+
28+
Notes
29+
-----
30+
This method name was changed from set_default_headers to set_cors_header
31+
when adding support for JupyterServer 2.x. In that release, JS changed the
32+
way the headers were implemented due to changes in the way a user is authenticated.
33+
See https://github.com/jupyter-server/jupyter_server/pull/671.
2734
"""
28-
super(CORSMixin, self).set_default_headers()
35+
super().set_cors_headers()
36+
2937
# Add CORS headers after default if they have a non-blank value
3038
for settings_name, header_name in self.SETTINGS_TO_HEADERS.items():
3139
header_value = self.settings.get(settings_name)
@@ -75,7 +83,7 @@ def prepare(self):
7583
client_token = None
7684
if client_token != server_token:
7785
return self.send_error(401)
78-
return super(TokenAuthorizationMixin, self).prepare()
86+
return super().prepare()
7987

8088

8189
class JSONErrorsMixin(object):

0 commit comments

Comments
 (0)