Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions jupyter_server/gateway/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from abc import ABC, ABCMeta, abstractmethod
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from http.cookies import SimpleCookie
from http.cookies import Morsel, SimpleCookie
from socket import gaierror

from jupyter_events import EventLogger
from tornado import web
from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
from tornado.httputil import HTTPHeaders
from traitlets import (
Bool,
Float,
Expand All @@ -40,9 +41,6 @@
STATUS_CODE_KEY = "status_code"
MESSAGE_KEY = "msg"

if ty.TYPE_CHECKING:
from http.cookies import Morsel


class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore[misc]
"""The metaclass necessary for proper ABC behavior in a Configurable."""
Expand Down Expand Up @@ -630,20 +628,41 @@ def load_connection_args(self, **kwargs):

return kwargs

def update_cookies(self, cookie: SimpleCookie) -> None:
"""Update cookies from existing requests for load balancers"""
def update_cookies(self, headers: HTTPHeaders) -> None:
"""Update cookies from response headers"""

if not self.accept_cookies:
return

# Get individual Set-Cookie headers in list form. This handles multiple cookies
# that are otherwise comma-separated in the header and will break the parsing logic
# if only headers.get() is used.
cookie_headers = headers.get_list("Set-Cookie")
if not cookie_headers:
return

store_time = datetime.now(tz=timezone.utc)
for key, item in cookie.items():
for header in cookie_headers:
cookie = SimpleCookie()
try:
cookie.load(header)
except Exception as e:
self.log.warning("Failed to parse cookie header %s: %s", header, e)
continue

if not cookie:
self.log.warning("No cookies found in header: %s", header)
continue
name, morsel = next(iter(cookie.items()))

# Convert "expires" arg into "max-age" to facilitate expiration management.
# As "max-age" has precedence, ignore "expires" when "max-age" exists.
if item.get("expires") and not item.get("max-age"):
expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
item["max-age"] = str(expire_timedelta.total_seconds())
if morsel.get("expires") and not morsel.get("max-age"):
expire_time = parsedate_to_datetime(morsel["expires"])
expire_timedelta = expire_time - store_time
morsel["max-age"] = str(expire_timedelta.total_seconds())

self._cookies[key] = (item, store_time)
self._cookies[name] = (morsel, store_time)

def _clear_expired_cookies(self) -> None:
"""Clear expired cookies."""
Expand Down Expand Up @@ -821,10 +840,6 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
raise e

if gateway_client.accept_cookies:
# Update cookies on GatewayClient from server if configured.
cookie_values = response.headers.get("Set-Cookie")
if cookie_values:
cookie: SimpleCookie = SimpleCookie()
cookie.load(cookie_values)
gateway_client.update_cookies(cookie)
gateway_client.update_cookies(response.headers)

return response
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ unfixable = [
[tool.ruff.lint.extend-per-file-ignores]
"jupyter_server/*" = ["S101", "RET", "S110", "UP031", "FBT", "FA100", "SLF001", "A002",
"SIM105", "A001", "UP007", "PLR2004", "T201", "N818", "F403"]
"jupyter_server/gateway/*" = ["TCH" ]
"tests/*" = ["UP031", "PT", 'EM', "TRY", "RET", "SLF", "C408", "F841", "FBT", "A002", "FLY", "N",
"PERF", "ASYNC", "T201", "FA100", "E711", "INP", "TCH", "SIM105", "A001", "PLW0603"]
"examples/*_config.py" = ["F821"]
Expand Down
55 changes: 34 additions & 21 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import os
import uuid
from datetime import datetime, timedelta, timezone
from email.utils import format_datetime
from http.cookies import SimpleCookie
from io import BytesIO
from queue import Empty
from typing import Any, Union
Expand All @@ -18,7 +16,7 @@
from jupyter_core.utils import ensure_async
from tornado.concurrent import Future
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.httputil import HTTPServerRequest
from tornado.httputil import HTTPHeaders, HTTPServerRequest
from tornado.queues import Queue
from tornado.web import HTTPError
from traitlets import Int, Unicode
Expand Down Expand Up @@ -376,12 +374,12 @@ def test_gateway_request_timeout_pad_option(


@pytest.mark.parametrize(
"accept_cookies,expire_arg,expire_param,existing_cookies,cookie_exists",
"accept_cookies,expire_arg,expire_param,existing_cookies",
[
(False, None, None, "EXISTING=1", False),
(True, None, None, "EXISTING=1", True),
(True, "Expires", 180, None, True),
(True, "Max-Age", "-360", "EXISTING=1", False),
(False, None, 0, "EXISTING=1"),
(True, None, 0, "EXISTING=1"),
(True, "expires", 180, None),
(True, "Max-Age", -360, "EXISTING=1"),
],
)
def test_gateway_request_with_expiring_cookies(
Expand All @@ -390,35 +388,50 @@ def test_gateway_request_with_expiring_cookies(
expire_arg,
expire_param,
existing_cookies,
cookie_exists,
):
argv = [f"--GatewayClient.accept_cookies={accept_cookies}"]

GatewayClient.clear_instance()
_ = jp_configurable_serverapp(argv=argv)

cookie: SimpleCookie = SimpleCookie()
cookie.load("SERVERID=1234567; Path=/")
if expire_arg == "Expires":
expire_param = format_datetime(
datetime.now(tz=timezone.utc) + timedelta(seconds=expire_param)
test_expiration = bool(expire_param < 0)
# Create mock headers with Set-Cookie values
headers = HTTPHeaders()
cookie_value = "SERVERID=1234567; Path=/; HttpOnly"
if expire_arg == "expires":
# Convert expire_param to a string in the format of "Expires: <date>" (RFC 7231)
expire_param = (datetime.now(tz=timezone.utc) + timedelta(seconds=expire_param)).strftime(
"%a, %d %b %Y %H:%M:%S GMT"
)
if expire_arg:
cookie["SERVERID"][expire_arg] = expire_param
cookie_value = f"SERVERID=1234567; Path=/; expires={expire_param}; HttpOnly"
elif expire_arg == "Max-Age":
cookie_value = f"SERVERID=1234567; Path=/; Max-Age={expire_param}; HttpOnly"

GatewayClient.instance().update_cookies(cookie)
# Add a second cookie to test comma-separated cookies
headers.add("Set-Cookie", cookie_value)
headers.add("Set-Cookie", "ADDITIONAL_COOKIE=8901234; Path=/; HttpOnly")

headers_list = headers.get_list("Set-Cookie")
print(headers_list)

GatewayClient.instance().update_cookies(headers)

args = {}
if existing_cookies:
args["headers"] = {"Cookie": existing_cookies}

connection_args = GatewayClient.instance().load_connection_args(**args)

if not cookie_exists:
assert "SERVERID" not in (connection_args["headers"].get("Cookie") or "")
if not accept_cookies or test_expiration:
# The first condition ensure the response cookie is not accepted,
# the second condition ensures that the cookie is not accepted if it is expired.
assert "SERVERID" not in connection_args["headers"].get("Cookie")
else:
assert "SERVERID" in connection_args["headers"].get("Cookie")
# The cookie is accepted if it is not expired and accept_cookies is True.
assert "SERVERID" in connection_args["headers"].get("Cookie", "")

if existing_cookies:
assert "EXISTING" in connection_args["headers"].get("Cookie")
assert "EXISTING" in connection_args["headers"].get("Cookie", "")

GatewayClient.clear_instance()

Expand Down
Loading