Skip to content

Commit 6804b8a

Browse files
s-aleshinprovinzkraut
authored andcommitted
feat(jwt): [#4191] extend Token.encode() to support custom headers (#4192)
1 parent 27bc140 commit 6804b8a

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

litestar/security/jwt/token.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __post_init__(self) -> None:
9090
def decode_payload(
9191
cls,
9292
encoded_token: str,
93-
secret: str,
93+
secret: str | bytes,
9494
algorithms: list[str],
9595
issuer: list[str] | None = None,
9696
audience: str | Sequence[str] | None = None,
@@ -110,7 +110,7 @@ def decode_payload(
110110
def decode(
111111
cls,
112112
encoded_token: str,
113-
secret: str,
113+
secret: str | bytes,
114114
algorithm: str,
115115
audience: str | Sequence[str] | None = None,
116116
issuer: str | Sequence[str] | None = None,
@@ -194,12 +194,18 @@ def decode(
194194
) as e:
195195
raise NotAuthorizedException("Invalid token") from e
196196

197-
def encode(self, secret: str, algorithm: str) -> str:
197+
def encode(
198+
self,
199+
secret: str | bytes,
200+
algorithm: str,
201+
headers: dict[str, Any] | None = None,
202+
) -> str:
198203
"""Encode the token instance into a string.
199204
200205
Args:
201206
secret: The secret with which the JWT is encoded.
202207
algorithm: The algorithm used to encode the JWT.
208+
headers: Optional headers to include in the JWT (e.g., {"kid": "..."}).
203209
204210
Returns:
205211
An encoded token string.
@@ -212,6 +218,7 @@ def encode(self, secret: str, algorithm: str) -> str:
212218
payload={k: v for k, v in asdict(self).items() if v is not None},
213219
key=secret,
214220
algorithm=algorithm,
221+
headers=headers,
215222
)
216223
except (jwt.DecodeError, NotImplementedError) as e:
217224
raise ImproperlyConfiguredException("Failed to encode token") from e

tests/unit/test_security/test_jwt/test_token.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class CustomToken(Token):
206206
def decode_payload(
207207
cls,
208208
encoded_token: str,
209-
secret: str,
209+
secret: str | bytes,
210210
algorithms: list[str],
211211
issuer: list[str] | None = None,
212212
audience: str | Sequence[str] | None = None,
@@ -223,3 +223,14 @@ def decode_payload(
223223
_secret = secrets.token_hex()
224224
encoded = CustomToken(exp=datetime.now() + timedelta(days=1), sub="foo").encode(_secret, "HS256")
225225
assert CustomToken.decode(encoded, secret=_secret, algorithm="HS256").sub == "some-random-value"
226+
227+
228+
def test_token_encode_includes_custom_headers() -> None:
229+
token = Token(exp=datetime.now() + timedelta(days=1), sub="some-random-value")
230+
custom_headers = {"kid": "key-id"}
231+
encoded = token.encode(secret=secrets.token_hex(), algorithm="HS256", headers=custom_headers)
232+
header = jwt.get_unverified_header(encoded)
233+
234+
assert header["alg"] == "HS256"
235+
assert "kid" in header
236+
assert header["kid"] == custom_headers["kid"]

0 commit comments

Comments
 (0)