Skip to content

Commit 8f99062

Browse files
kevin-batesKevin Bates
andauthored
Fix gateway cookie handling (#1558)
Co-authored-by: Kevin Bates <[email protected]>
1 parent b847905 commit 8f99062

File tree

3 files changed

+67
-38
lines changed

3 files changed

+67
-38
lines changed

jupyter_server/gateway/gateway_client.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
from abc import ABC, ABCMeta, abstractmethod
1313
from datetime import datetime, timezone
1414
from email.utils import parsedate_to_datetime
15-
from http.cookies import SimpleCookie
15+
from http.cookies import Morsel, SimpleCookie
1616
from socket import gaierror
1717

1818
from jupyter_events import EventLogger
1919
from tornado import web
2020
from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse
21+
from tornado.httputil import HTTPHeaders
2122
from traitlets import (
2223
Bool,
2324
Float,
@@ -40,9 +41,6 @@
4041
STATUS_CODE_KEY = "status_code"
4142
MESSAGE_KEY = "msg"
4243

43-
if ty.TYPE_CHECKING:
44-
from http.cookies import Morsel
45-
4644

4745
class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore[misc]
4846
"""The metaclass necessary for proper ABC behavior in a Configurable."""
@@ -630,20 +628,41 @@ def load_connection_args(self, **kwargs):
630628

631629
return kwargs
632630

633-
def update_cookies(self, cookie: SimpleCookie) -> None:
634-
"""Update cookies from existing requests for load balancers"""
631+
def update_cookies(self, headers: HTTPHeaders) -> None:
632+
"""Update cookies from response headers"""
633+
635634
if not self.accept_cookies:
636635
return
637636

637+
# Get individual Set-Cookie headers in list form. This handles multiple cookies
638+
# that are otherwise comma-separated in the header and will break the parsing logic
639+
# if only headers.get() is used.
640+
cookie_headers = headers.get_list("Set-Cookie")
641+
if not cookie_headers:
642+
return
643+
638644
store_time = datetime.now(tz=timezone.utc)
639-
for key, item in cookie.items():
645+
for header in cookie_headers:
646+
cookie = SimpleCookie()
647+
try:
648+
cookie.load(header)
649+
except Exception as e:
650+
self.log.warning("Failed to parse cookie header %s: %s", header, e)
651+
continue
652+
653+
if not cookie:
654+
self.log.warning("No cookies found in header: %s", header)
655+
continue
656+
name, morsel = next(iter(cookie.items()))
657+
640658
# Convert "expires" arg into "max-age" to facilitate expiration management.
641659
# As "max-age" has precedence, ignore "expires" when "max-age" exists.
642-
if item.get("expires") and not item.get("max-age"):
643-
expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
644-
item["max-age"] = str(expire_timedelta.total_seconds())
660+
if morsel.get("expires") and not morsel.get("max-age"):
661+
expire_time = parsedate_to_datetime(morsel["expires"])
662+
expire_timedelta = expire_time - store_time
663+
morsel["max-age"] = str(expire_timedelta.total_seconds())
645664

646-
self._cookies[key] = (item, store_time)
665+
self._cookies[name] = (morsel, store_time)
647666

648667
def _clear_expired_cookies(self) -> None:
649668
"""Clear expired cookies."""
@@ -821,10 +840,6 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
821840
raise e
822841

823842
if gateway_client.accept_cookies:
824-
# Update cookies on GatewayClient from server if configured.
825-
cookie_values = response.headers.get("Set-Cookie")
826-
if cookie_values:
827-
cookie: SimpleCookie = SimpleCookie()
828-
cookie.load(cookie_values)
829-
gateway_client.update_cookies(cookie)
843+
gateway_client.update_cookies(response.headers)
844+
830845
return response

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ unfixable = [
156156
[tool.ruff.lint.extend-per-file-ignores]
157157
"jupyter_server/*" = ["S101", "RET", "S110", "UP031", "FBT", "FA100", "SLF001", "A002",
158158
"SIM105", "A001", "UP007", "PLR2004", "T201", "N818", "F403"]
159+
"jupyter_server/gateway/*" = ["TCH" ]
159160
"tests/*" = ["UP031", "PT", 'EM', "TRY", "RET", "SLF", "C408", "F841", "FBT", "A002", "FLY", "N",
160161
"PERF", "ASYNC", "T201", "FA100", "E711", "INP", "TCH", "SIM105", "A001", "PLW0603"]
161162
"examples/*_config.py" = ["F821"]

tests/test_gateway.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import os
77
import uuid
88
from datetime import datetime, timedelta, timezone
9-
from email.utils import format_datetime
10-
from http.cookies import SimpleCookie
119
from io import BytesIO
1210
from queue import Empty
1311
from typing import Any, Union
@@ -18,7 +16,7 @@
1816
from jupyter_core.utils import ensure_async
1917
from tornado.concurrent import Future
2018
from tornado.httpclient import HTTPRequest, HTTPResponse
21-
from tornado.httputil import HTTPServerRequest
19+
from tornado.httputil import HTTPHeaders, HTTPServerRequest
2220
from tornado.queues import Queue
2321
from tornado.web import HTTPError
2422
from traitlets import Int, Unicode
@@ -376,12 +374,12 @@ def test_gateway_request_timeout_pad_option(
376374

377375

378376
@pytest.mark.parametrize(
379-
"accept_cookies,expire_arg,expire_param,existing_cookies,cookie_exists",
377+
"accept_cookies,expire_arg,expire_param,existing_cookies",
380378
[
381-
(False, None, None, "EXISTING=1", False),
382-
(True, None, None, "EXISTING=1", True),
383-
(True, "Expires", 180, None, True),
384-
(True, "Max-Age", "-360", "EXISTING=1", False),
379+
(False, None, 0, "EXISTING=1"),
380+
(True, None, 0, "EXISTING=1"),
381+
(True, "expires", 180, None),
382+
(True, "Max-Age", -360, "EXISTING=1"),
385383
],
386384
)
387385
def test_gateway_request_with_expiring_cookies(
@@ -390,35 +388,50 @@ def test_gateway_request_with_expiring_cookies(
390388
expire_arg,
391389
expire_param,
392390
existing_cookies,
393-
cookie_exists,
394391
):
395392
argv = [f"--GatewayClient.accept_cookies={accept_cookies}"]
396393

397394
GatewayClient.clear_instance()
398395
_ = jp_configurable_serverapp(argv=argv)
399396

400-
cookie: SimpleCookie = SimpleCookie()
401-
cookie.load("SERVERID=1234567; Path=/")
402-
if expire_arg == "Expires":
403-
expire_param = format_datetime(
404-
datetime.now(tz=timezone.utc) + timedelta(seconds=expire_param)
397+
test_expiration = bool(expire_param < 0)
398+
# Create mock headers with Set-Cookie values
399+
headers = HTTPHeaders()
400+
cookie_value = "SERVERID=1234567; Path=/; HttpOnly"
401+
if expire_arg == "expires":
402+
# Convert expire_param to a string in the format of "Expires: <date>" (RFC 7231)
403+
expire_param = (datetime.now(tz=timezone.utc) + timedelta(seconds=expire_param)).strftime(
404+
"%a, %d %b %Y %H:%M:%S GMT"
405405
)
406-
if expire_arg:
407-
cookie["SERVERID"][expire_arg] = expire_param
406+
cookie_value = f"SERVERID=1234567; Path=/; expires={expire_param}; HttpOnly"
407+
elif expire_arg == "Max-Age":
408+
cookie_value = f"SERVERID=1234567; Path=/; Max-Age={expire_param}; HttpOnly"
408409

409-
GatewayClient.instance().update_cookies(cookie)
410+
# Add a second cookie to test comma-separated cookies
411+
headers.add("Set-Cookie", cookie_value)
412+
headers.add("Set-Cookie", "ADDITIONAL_COOKIE=8901234; Path=/; HttpOnly")
413+
414+
headers_list = headers.get_list("Set-Cookie")
415+
print(headers_list)
416+
417+
GatewayClient.instance().update_cookies(headers)
410418

411419
args = {}
412420
if existing_cookies:
413421
args["headers"] = {"Cookie": existing_cookies}
422+
414423
connection_args = GatewayClient.instance().load_connection_args(**args)
415424

416-
if not cookie_exists:
417-
assert "SERVERID" not in (connection_args["headers"].get("Cookie") or "")
425+
if not accept_cookies or test_expiration:
426+
# The first condition ensure the response cookie is not accepted,
427+
# the second condition ensures that the cookie is not accepted if it is expired.
428+
assert "SERVERID" not in connection_args["headers"].get("Cookie")
418429
else:
419-
assert "SERVERID" in connection_args["headers"].get("Cookie")
430+
# The cookie is accepted if it is not expired and accept_cookies is True.
431+
assert "SERVERID" in connection_args["headers"].get("Cookie", "")
432+
420433
if existing_cookies:
421-
assert "EXISTING" in connection_args["headers"].get("Cookie")
434+
assert "EXISTING" in connection_args["headers"].get("Cookie", "")
422435

423436
GatewayClient.clear_instance()
424437

0 commit comments

Comments
 (0)