Skip to content

Commit 70afb69

Browse files
authored
Merge pull request #10 from python-ellar/fixe_override_jwt_config
fixed jwt config override bug
2 parents 37e0eb1 + 038ec98 commit 70afb69

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

ellar_jwt/services.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import typing as t
23
from datetime import timedelta
34

@@ -75,7 +76,10 @@ async def sign_async(
7576
headers: t.Optional[t.Dict[str, t.Any]] = None,
7677
**jwt_config: t.Any,
7778
) -> str:
78-
return await anyio.to_thread.run_sync(self.sign, payload, headers, **jwt_config)
79+
func = self.sign
80+
if jwt_config:
81+
func = functools.partial(self.sign, **jwt_config)
82+
return await anyio.to_thread.run_sync(func, payload, headers)
7983

8084
def decode(
8185
self, token: str, verify: bool = True, **jwt_config: t.Any
@@ -109,4 +113,7 @@ def decode(
109113
async def decode_async(
110114
self, token: str, verify: bool = True, **jwt_config: t.Any
111115
) -> t.Dict[str, t.Any]:
112-
return await anyio.to_thread.run_sync(self.decode, token, verify, **jwt_config)
116+
func = self.decode
117+
if jwt_config:
118+
func = functools.partial(self.decode, **jwt_config)
119+
return await anyio.to_thread.run_sync(func, token, verify)

tests/test_jwt_service.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,18 @@ async def test_sign_async_and_decode_async(self):
471471
token = await backend.sign_async(self.payload)
472472
decoded = await backend.decode_async(token)
473473
assert decoded["uuid"] == str(unique)
474+
475+
@pytest.mark.asyncio
476+
async def test_sign_async_with_override_jwt_config(self):
477+
backend = JWTService(
478+
JWTConfiguration(
479+
algorithm="HS256",
480+
signing_secret_key=SECRET,
481+
json_encoder=UUIDJSONEncoder,
482+
)
483+
)
484+
unique = uuid.uuid4()
485+
self.payload["uuid"] = unique
486+
token = await backend.sign_async(self.payload, algorithm="HS384")
487+
decoded = await backend.decode_async(token, algorithm="HS384")
488+
assert decoded["uuid"] == str(unique)

0 commit comments

Comments
 (0)