Skip to content

Commit 74f4860

Browse files
ashbjscheffl
andauthored
Re-work JWT Validation and Generation to use public/private key and official claims (apache#46981)
Since requests to Task Execution API can originate out-of-"cluser" so to speak, this PR re-works the JWTSigner class so that it is possible to use public/private keys (vs just a simple pre-shared secret). This is useful in many cloud environments where many companies have security requirements that ingress gateways must validate the JWT tokens on the way in, and the only way of doing this is with public keys So that we don't have two different ways of generating JWT tokens I have totally replaced the old "JWTSigner" class (which it turns out didn't have any unit test of its own, it was only tested indirectly through test_serve_logs etc). As part of this change I have also changed the JWT that was generated by the SimpleAuthManager and the AwsAuthManager (the only two we have that use JWT) to use the offical `sub` (subject) clain to place the user identifer rather than a custom claim name. And although it might seem slightly strange at first, I have made the JWTValidator an async class internally. (Hence `avalidated_claims` -- the `a` prefix signifies async, much like `aclose` or `aread` on HTTPX async responses). This allows us to periodically refresh the JWK document if configured in the background, and using asgiref's async_to_sync means we only have one version. Co-authored-by: Jens Scheffler <jscheffl@apache.org>
1 parent 4b85cb6 commit 74f4860

File tree

34 files changed

+1055
-256
lines changed

34 files changed

+1055
-256
lines changed

airflow/api_fastapi/auth/managers/base_auth_manager.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,24 @@
1919

2020
import logging
2121
from abc import ABCMeta, abstractmethod
22+
from functools import cache
2223
from typing import TYPE_CHECKING, Any, Generic, TypeVar
2324

2425
from jwt import InvalidTokenError
2526
from sqlalchemy import select
2627

2728
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
2829
from airflow.api_fastapi.auth.managers.models.resource_details import BackfillDetails, DagDetails
30+
from airflow.api_fastapi.auth.tokens import (
31+
JWTGenerator,
32+
JWTValidator,
33+
get_sig_validation_args,
34+
get_signing_args,
35+
)
2936
from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem
3037
from airflow.configuration import conf
3138
from airflow.models import DagModel
3239
from airflow.typing_compat import Literal
33-
from airflow.utils.jwt_signer import JWTSigner, get_signing_key
3440
from airflow.utils.log.logging_mixin import LoggingMixin
3541
from airflow.utils.session import NEW_SESSION, provide_session
3642

@@ -86,24 +92,24 @@ def deserialize_user(self, token: dict[str, Any]) -> T:
8692

8793
@abstractmethod
8894
def serialize_user(self, user: T) -> dict[str, Any]:
89-
"""Create a dict from a user object."""
95+
"""Create a subject and extra claims dict from a user object."""
9096

91-
def get_user_from_token(self, token: str) -> BaseUser:
97+
async def get_user_from_token(self, token: str) -> BaseUser:
9298
"""Verify the JWT token is valid and create a user object from it if valid."""
9399
try:
94-
payload: dict[str, Any] = self._get_token_signer().verify_token(token)
100+
payload: dict[str, Any] = await self._get_token_validator().avalidated_claims(token)
95101
return self.deserialize_user(payload)
96102
except InvalidTokenError as e:
97-
log.error("JWT token is not valid")
103+
log.error("JWT token is not valid: %s", e)
98104
raise e
99105

100-
def get_jwt_token(
101-
self, user: T, *, expiration_time_in_seconds: int = conf.getint("api", "auth_jwt_expiration_time")
106+
def generate_jwt(
107+
self, user: T, *, expiration_time_in_seconds: int = conf.getint("api_auth", "jwt_expiration_time")
102108
) -> str:
103109
"""Return the JWT token from a user object."""
104-
return self._get_token_signer(
105-
expiration_time_in_seconds=expiration_time_in_seconds
106-
).generate_signed_token(self.serialize_user(user))
110+
return self._get_token_signer(expiration_time_in_seconds=expiration_time_in_seconds).generate(
111+
self.serialize_user(user)
112+
)
107113

108114
@abstractmethod
109115
def get_url_login(self, **kwargs) -> str:
@@ -450,19 +456,35 @@ def get_extra_menu_items(self, *, user: T) -> list[ExtraMenuItem]:
450456
"""
451457
return []
452458

453-
@staticmethod
459+
@classmethod
460+
@cache
454461
def _get_token_signer(
455-
expiration_time_in_seconds: int = conf.getint("api", "auth_jwt_expiration_time"),
456-
) -> JWTSigner:
462+
cls,
463+
expiration_time_in_seconds: int = conf.getint("api_auth", "jwt_expiration_time"),
464+
) -> JWTGenerator:
457465
"""
458466
Return the signer used to sign JWT token.
459467
460468
:meta private:
461469
462470
:param expiration_time_in_seconds: expiration time in seconds of the token
463471
"""
464-
return JWTSigner(
465-
secret_key=get_signing_key("api", "auth_jwt_secret"),
466-
expiration_time_in_seconds=expiration_time_in_seconds,
467-
audience="front-apis",
472+
return JWTGenerator(
473+
**get_signing_args(),
474+
valid_for=expiration_time_in_seconds,
475+
audience=conf.get("api", "jwt_audience", fallback="apache-airflow"),
476+
)
477+
478+
@classmethod
479+
@cache
480+
def _get_token_validator(cls) -> JWTValidator:
481+
"""
482+
Return the signer used to sign JWT token.
483+
484+
:meta private:
485+
"""
486+
return JWTValidator(
487+
**get_sig_validation_args(),
488+
leeway=conf.getint("api_auth", "jwt_leeway"),
489+
audience=conf.get("api_auth", "jwt_audience", fallback="apache-airflow"),
468490
)

airflow/api_fastapi/auth/managers/simple/routes/login.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_token_all_admins() -> RedirectResponse:
6262
username="Anonymous",
6363
role="ADMIN",
6464
)
65-
url = urljoin(conf.get("api", "base_url"), f"?token={get_auth_manager().get_jwt_token(user)}")
65+
url = urljoin(conf.get("api", "base_url"), f"?token={get_auth_manager().generate_jwt(user)}")
6666
return RedirectResponse(url=url)
6767

6868

@@ -76,5 +76,5 @@ def create_token_cli(
7676
) -> LoginResponse:
7777
"""Authenticate the user for the CLI."""
7878
return SimpleAuthManagerLogin.create_token(
79-
body=body, expiration_time_in_sec=conf.getint("api", "auth_jwt_cli_expiration_time")
79+
body=body, expiration_time_in_sec=conf.getint("api_auth", "jwt_cli_expiration_time")
8080
)

airflow/api_fastapi/auth/managers/simple/services/login.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SimpleAuthManagerLogin:
3131

3232
@classmethod
3333
def create_token(
34-
cls, body: LoginBody, expiration_time_in_sec: int = conf.getint("api", "auth_jwt_expiration_time")
34+
cls, body: LoginBody, expiration_time_in_sec: int = conf.getint("api_auth", "jwt_expiration_time")
3535
) -> LoginResponse:
3636
"""
3737
Authenticate user with given configuration.
@@ -67,7 +67,7 @@ def create_token(
6767
)
6868

6969
return LoginResponse(
70-
jwt_token=get_auth_manager().get_jwt_token(
70+
jwt_token=get_auth_manager().generate_jwt(
7171
user=user, expiration_time_in_seconds=expiration_time_in_sec
7272
)
7373
)

airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ def get_url_login(self, **kwargs) -> str:
139139
return AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login"
140140

141141
def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser:
142-
return SimpleAuthManagerUser(username=token["username"], role=token["role"])
142+
return SimpleAuthManagerUser(username=token["sub"], role=token["role"])
143143

144144
def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]:
145-
return {"username": user.username, "role": user.role}
145+
return {"sub": user.username, "role": user.role}
146146

147147
def is_authorized_configuration(
148148
self,

0 commit comments

Comments
 (0)