Skip to content

Commit bb349b4

Browse files
committed
WIP
Signed-off-by: David Black <[email protected]>
1 parent 95945db commit bb349b4

File tree

22 files changed

+113
-82
lines changed

22 files changed

+113
-82
lines changed

atlassian_jwt_auth/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def create(cls, issuer: str, key_identifier: Union[KeyIdentifier, str], private_
2626

2727
def _get_header_value(self) -> bytes:
2828
return b'Bearer ' + self._signer.generate_jwt(
29-
self._audience, additional_claims=self._additional_claims)
29+
self._audience, additional_claims=self._additional_claims).encode("utf-8")

atlassian_jwt_auth/contrib/aiohttp/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def encode(self) -> str:
1919

2020

2121
def create_jwt_auth(
22-
issuer: str, key_identifier: Union[KeyIdentifier, str], private_key_pem: str, audience: str, **kwargs: Any) -> JWTAuth:
22+
issuer: str, key_identifier: Union[KeyIdentifier, str], private_key_pem: str, audience: str, **kwargs: Any) -> BaseJWTAuth:
2323
"""Instantiate a JWTAuth while creating the signer inline"""
2424
return JWTAuth.create(
2525
issuer, key_identifier, private_key_pem, audience, **kwargs)

atlassian_jwt_auth/contrib/aiohttp/key.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import urllib.parse
33
from asyncio import AbstractEventLoop
4-
from typing import Any, Dict, Optional
4+
from typing import Any, Dict, Optional, Awaitable
55

66
import aiohttp
77

@@ -22,7 +22,7 @@ def __init__(self, base_url: str, *,
2222
self.loop = loop
2323
super().__init__(base_url)
2424

25-
def _get_session(self) -> aiohttp.ClientSession:
25+
def _get_session(self) -> aiohttp.ClientSession: # type: ignore[override]
2626
if HTTPSPublicKeyRetriever._class_session is None:
2727
HTTPSPublicKeyRetriever._class_session = aiohttp.ClientSession(
2828
loop=self.loop)
@@ -43,11 +43,11 @@ def _convert_proxies_to_proxy_arg(
4343
return requests_kwargs
4444

4545
async def _retrieve(
46-
self, url: str, requests_kwargs: Dict[Any, Any]) -> str:
46+
self, url: str, requests_kwargs: Dict[Any, Any]) -> Awaitable[str]:
4747
requests_kwargs = self._convert_proxies_to_proxy_arg(
4848
url, requests_kwargs)
4949
try:
50-
resp = await self._session.get(url, headers={'accept':
50+
resp = await self._session.get(url, headers={'accept': # type: ignore[misc]
5151
PEM_FILE_TYPE},
5252
**requests_kwargs)
5353
resp.raise_for_status()

atlassian_jwt_auth/contrib/aiohttp/verifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Coroutine, Union
33

44
import jwt
55

66
from atlassian_jwt_auth import key
77
from atlassian_jwt_auth.verifier import JWTAuthVerifier as _JWTAuthVerifier
88

99

10-
class JWTAuthVerifier(_JWTAuthVerifier):
11-
async def verify_jwt(self, a_jwt: str, audience: str,
10+
class JWTAuthVerifier(_JWTAuthVerifier): # type: ignore[override]
11+
async def verify_jwt(self, a_jwt: str, audience: str, # type: ignore[override]
1212
leeway: int = 0, **requests_kwargs: Any) -> Dict[Any, Any]:
1313
"""Verify if the token is correct
1414

atlassian_jwt_auth/contrib/django/middleware.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def process_request(self, request) -> Optional[str]:
4949
request.META['HTTP_AUTHORIZATION'] = orig_auth
5050
if asap_auth is not None:
5151
request.META[self.xauth] = asap_auth
52+
return None
5253

5354
def process_view(self, request: Any, view_func: Callable,
5455
view_args: Any, view_kwargs: Any) -> None:

atlassian_jwt_auth/contrib/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
class JWTAuth(AuthBase, BaseJWTAuth):
1313
"""Adds a JWT bearer token to the request per the ASAP specification"""
1414

15-
def __call__(self, r: requests.Request):
16-
r.headers['Authorization'] = self._get_header_value()
15+
def __call__(self, r: requests.models.PreparedRequest) -> requests.models.PreparedRequest:
16+
r.headers['Authorization'] = self._get_header_value() # type: ignore[assignment]
1717
return r
1818

1919

atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class BaseAuthTest(test_requests.BaseRequestsTest):
1313
def _get_auth_header(self, auth) -> bytes:
1414
return auth.encode().encode('latin1')
1515

16-
def create_jwt_auth(self, *args: Any, **kwargs: Dict):
16+
def create_jwt_auth(self, *args: Any, **kwargs: Any):
1717
return create_jwt_auth(*args, **kwargs)
1818

1919

atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class DummyHTTPSPublicKeyRetriever(HTTPSPublicKeyRetriever):
2323
def set_headers(self, headers) -> None:
2424
self._session.get.return_value.headers.update(headers)
2525

26-
def set_text(self, text: str) -> None:
27-
self._session.get.return_value.text.return_value = text
26+
def set_text(self, text: str | bytes) -> None:
27+
self._session.get.return_value.text.return_value = text # type: ignore[attr-defined]
2828

2929
def _get_session(self) -> Mock:
3030
session = Mock(spec=aiohttp.ClientSession)
@@ -41,16 +41,16 @@ class BaseHTTPSPublicKeyRetrieverTestMixin(object):
4141
"""Tests for aiohttp.HTTPSPublicKeyRetriever class for RS256 algorithm"""
4242

4343
def setUp(self) -> None:
44-
self._private_key_pem = self.get_new_private_key_in_pem_format()
44+
self._private_key_pem = self.get_new_private_key_in_pem_format() # type: ignore[attr-defined]
4545
self._public_key_pem = utils.get_public_key_pem_for_private_key_pem(
4646
self._private_key_pem)
4747
self.base_url = 'https://example.com'
4848

4949
async def test_retrieve(self) -> None:
5050
"""Check if retrieve method returns public key"""
5151
retriever = DummyHTTPSPublicKeyRetriever(self.base_url)
52-
retriever.set_text(self._public_key_pem)
53-
self.assertEqual(
52+
retriever.set_text(self._public_key_pem) # type: ignore[misc]
53+
self.assertEqual( # type: ignore[attr-defined]
5454
await retriever.retrieve('example/eg'),
5555
self._public_key_pem)
5656

@@ -61,7 +61,7 @@ async def test_retrieve_with_charset_in_content_type_h(self) -> None:
6161
retriever.set_text(self._public_key_pem)
6262
retriever.set_headers(headers)
6363

64-
self.assertEqual(
64+
self.assertEqual( # type: ignore[attr-defined]
6565
await retriever.retrieve('example/eg'),
6666
self._public_key_pem)
6767

@@ -71,10 +71,10 @@ async def test_retrieve_fails_with_different_content_type(self) -> None:
7171
"""
7272
headers = {'content-type': 'different/not-supported'}
7373
retriever = DummyHTTPSPublicKeyRetriever(self.base_url)
74-
retriever.set_text(self._public_key_pem)
74+
retriever.set_text(self._public_key_pem) # type: ignore[arg-type]
7575
retriever.set_headers(headers)
7676

77-
with self.assertRaises(ValueError):
77+
with self.assertRaises(ValueError): # type: ignore[attr-defined]
7878
await retriever.retrieve('example/eg')
7979

8080
async def test_retrieve_session_uses_env_proxy(self) -> None:
@@ -87,9 +87,9 @@ async def test_retrieve_session_uses_env_proxy(self) -> None:
8787
proxy_location)
8888
with mock.patch.dict(os.environ, proxy_dict, clear=True):
8989
retriever = DummyHTTPSPublicKeyRetriever(self.base_url)
90-
self.assertEqual(retriever._proxies, expected_proxies)
90+
self.assertEqual(retriever._proxies, expected_proxies) # type: ignore[attr-defined]
9191
await retriever.retrieve(key_id)
92-
retriever._session.get.assert_called_once_with(
92+
retriever._session.get.assert_called_once_with( #type: ignore[attr-defined]
9393
f'{self.base_url}/{key_id}', headers={'accept': PEM_FILE_TYPE},
9494
proxy=expected_proxies[self.base_url.split(':')[0]]
9595
)

atlassian_jwt_auth/contrib/tests/test_requests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
from datetime import timedelta
3-
from typing import Any
3+
from typing import Any, Type
44

55
import jwt
66
from requests import Request
@@ -14,7 +14,7 @@
1414
class BaseRequestsTest(object):
1515

1616
""" tests for the contrib.requests.JWTAuth class """
17-
auth_cls = JWTAuth
17+
auth_cls: Type[BaseJWTAuth] = JWTAuth
1818

1919
def setUp(self) -> None:
2020
self._private_key_pem = self.get_new_private_key_in_pem_format()
@@ -33,7 +33,7 @@ def assert_authorization_header_is_valid(self, auth) -> Any:
3333
return jwt.decode(bearer, self._public_key_pem.decode(),
3434
audience='audience', algorithms=algorithms)
3535

36-
def _get_auth_header(self, auth) -> str:
36+
def _get_auth_header(self, auth) -> bytes:
3737
request = auth(Request())
3838
auth_header = request.headers['Authorization']
3939
return auth_header

atlassian_jwt_auth/contrib/tests/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Type
22

33
import requests
44

55
import atlassian_jwt_auth
66
from atlassian_jwt_auth import JWTAuthVerifier
7+
from atlassian_jwt_auth.key import BasePublicKeyRetriever
78

89

9-
def get_static_retriever_class(keys: Dict[str, Any]):
10+
def get_static_retriever_class(keys: Dict[str, Any]) -> Type[BasePublicKeyRetriever]:
1011

11-
class StaticPublicKeyRetriever(object):
12+
class StaticPublicKeyRetriever(BasePublicKeyRetriever):
1213
""" Retrieves a key from a static dict of public keys
1314
(for use in tests only) """
1415

1516
def __init__(self, *args: Any, **
16-
kwargs: Any) -> requests.PreparedRequest:
17+
kwargs: Any) -> None:
1718
self.keys: Dict[str, Any] = keys
1819

19-
def retrieve(self, key_identifier, **requests_kwargs) -> str:
20+
def retrieve(self, key_identifier, **requests_kwargs) -> Any:
2021
return self.keys[key_identifier.key_id]
2122

2223
return StaticPublicKeyRetriever

0 commit comments

Comments
 (0)