Skip to content

Commit 7e8ab42

Browse files
authored
Merge pull request #13 from eadwinCode/guard_test
Guard Module Testing
2 parents 9608edc + 3a6ad58 commit 7e8ab42

File tree

6 files changed

+245
-31
lines changed

6 files changed

+245
-31
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ You will see the automatic interactive API documentation (provided by <a href="h
107107
Project is still in development
108108
- Remaining testing modules:
109109
- configuration
110-
- guard
111110
- Project CLI scaffolding
112111
- Documentation
113112
- Database Plugin with [Encode/ORM](https://github.com/encode/orm)

ellar/core/guard/base.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,32 @@
33

44
from pydantic import BaseModel
55
from starlette.exceptions import HTTPException
6-
from starlette.status import HTTP_403_FORBIDDEN
6+
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
77

88
from ellar.core.connection import HTTPConnection
99
from ellar.core.context import ExecutionContext
1010
from ellar.exceptions import APIException
1111

1212

1313
class GuardCanActivate(ABC, metaclass=ABCMeta):
14-
_exception_class: t.Type[HTTPException] = HTTPException
15-
_status_code: int = HTTP_403_FORBIDDEN
16-
_detail: str = "Not authenticated"
14+
exception_class: t.Type[HTTPException] = HTTPException
15+
status_code: int = HTTP_403_FORBIDDEN
16+
detail: str = "Not authenticated"
1717

1818
@abstractmethod
1919
async def can_activate(self, context: ExecutionContext) -> bool:
2020
pass
2121

2222
def raise_exception(self) -> None:
23-
raise self._exception_class(status_code=self._status_code, detail=self._detail)
23+
raise self.exception_class(status_code=self.status_code, detail=self.detail)
2424

2525

2626
class BaseAuthGuard(GuardCanActivate, ABC, metaclass=ABCMeta):
27+
status_code = HTTP_401_UNAUTHORIZED
2728
openapi_scope: t.List = []
29+
openapi_in: t.Optional[str] = None
30+
openapi_description: t.Optional[str] = None
31+
openapi_name: t.Optional[str] = None
2832

2933
@abstractmethod
3034
async def handle_request(self, *, connection: HTTPConnection) -> t.Optional[t.Any]:
@@ -40,6 +44,7 @@ async def can_activate(self, context: ExecutionContext) -> bool:
4044
result = await self.handle_request(connection=connection)
4145
if result:
4246
# auth parameter on request
47+
connection.scope["user"] = result
4348
return True
4449
return False
4550

@@ -55,9 +60,8 @@ class HTTPAuthorizationCredentials(BaseModel):
5560

5661

5762
class BaseAPIKey(BaseAuthGuard, ABC, metaclass=ABCMeta):
58-
openapi_in: t.Optional[str] = None
63+
exception_class = APIException
5964
parameter_name: str = "key"
60-
openapi_description: t.Optional[str] = None
6165

6266
def __init__(self) -> None:
6367
self.name = self.parameter_name
@@ -66,9 +70,7 @@ def __init__(self) -> None:
6670
async def handle_request(self, connection: HTTPConnection) -> t.Optional[t.Any]:
6771
key = self._get_key(connection)
6872
if not key:
69-
raise APIException(
70-
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
71-
)
73+
self.raise_exception()
7274
return await self.authenticate(connection, key)
7375

7476
@abstractmethod
@@ -88,12 +90,11 @@ def get_guard_scheme(cls) -> t.Dict:
8890
"type": "apiKey",
8991
"description": cls.openapi_description,
9092
"in": cls.openapi_in,
91-
"name": cls.__name__,
93+
"name": cls.openapi_name or cls.__name__,
9294
}
9395

9496

9597
class BaseHttpAuth(BaseAuthGuard, ABC, metaclass=ABCMeta):
96-
openapi_description: t.Optional[str] = None
9798
openapi_scheme: t.Optional[str] = None
9899
realm: t.Optional[str] = None
99100

@@ -130,5 +131,5 @@ def get_guard_scheme(cls) -> t.Dict:
130131
"type": "http",
131132
"description": cls.openapi_description,
132133
"scheme": cls.openapi_scheme,
133-
"name": cls.__name__,
134+
"name": cls.openapi_name or cls.__name__,
134135
}

ellar/core/guard/http.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
from abc import ABC
44
from base64 import b64decode
55

6-
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
7-
86
from ellar.core.connection import HTTPConnection
97
from ellar.exceptions import APIException, AuthenticationFailed
108

119
from .base import BaseHttpAuth, HTTPAuthorizationCredentials, HTTPBasicCredentials
1210

1311

1412
class HttpBearerAuth(BaseHttpAuth, ABC):
13+
exception_class = APIException
1514
openapi_scheme: str = "bearer"
1615
openapi_bearer_format: t.Optional[str] = None
1716
header: str = "Authorization"
@@ -28,18 +27,17 @@ def _get_credentials(
2827
authorization: str = connection.headers.get(self.header)
2928
scheme, _, credentials = self.authorization_partitioning(authorization)
3029
if not (authorization and scheme and credentials):
31-
raise APIException(
32-
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
33-
)
34-
if scheme.lower() != self.openapi_scheme:
35-
raise APIException(
36-
status_code=HTTP_403_FORBIDDEN,
30+
self.raise_exception()
31+
if scheme and str(scheme).lower() != self.openapi_scheme:
32+
raise self.exception_class(
33+
status_code=self.status_code,
3734
detail="Invalid authentication credentials",
3835
)
3936
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
4037

4138

4239
class HttpBasicAuth(BaseHttpAuth, ABC):
40+
exception_class = APIException
4341
openapi_scheme: str = "basic"
4442
realm: t.Optional[str] = None
4543
header = "Authorization"
@@ -50,33 +48,39 @@ def _not_unauthorized_exception(self, message: str) -> None:
5048
else:
5149
unauthorized_headers = {"WWW-Authenticate": "Basic"}
5250
raise AuthenticationFailed(
53-
status_code=HTTP_401_UNAUTHORIZED,
51+
status_code=self.status_code,
5452
detail=message,
5553
headers=unauthorized_headers,
5654
)
5755

5856
def _get_credentials(self, connection: HTTPConnection) -> HTTPBasicCredentials:
5957
authorization: str = connection.headers.get(self.header)
60-
scheme, _, credentials = self.authorization_partitioning(authorization)
58+
parts = authorization.split(" ") if authorization else []
59+
scheme, credentials = str(), str()
60+
61+
if len(parts) == 1:
62+
credentials = parts[0]
63+
scheme = "basic"
64+
elif len(parts) == 2:
65+
credentials = parts[1]
66+
scheme = parts[0].lower()
6167

6268
if (
6369
not (authorization and scheme and credentials)
6470
or scheme.lower() != self.openapi_scheme
6571
):
66-
raise APIException(
67-
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
68-
)
72+
self.raise_exception()
73+
6974
data: t.Optional[t.Union[str, bytes]] = None
7075
try:
7176
data = b64decode(credentials).decode("ascii")
7277
except (ValueError, UnicodeDecodeError, binascii.Error):
7378
self._not_unauthorized_exception("Invalid authentication credentials")
7479

7580
username, separator, password = (
76-
str(data).partition(":") if data else None,
77-
None,
78-
None,
81+
str(data).partition(":") if data else (None, None, None)
7982
)
83+
8084
if not separator:
8185
self._not_unauthorized_exception("Invalid authentication credentials")
8286
return HTTPBasicCredentials(username=username, password=password)

tests/notstarted/test_guard/__init__.py

Whitespace-only changes.

tests/test_guard.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import pytest
2+
from starlette.status import HTTP_401_UNAUTHORIZED
3+
4+
from ellar.common import Req, get, guards
5+
from ellar.core import AppFactory, TestClient
6+
from ellar.core.guard import (
7+
APIKeyCookie,
8+
APIKeyHeader,
9+
APIKeyQuery,
10+
HttpBasicAuth,
11+
HttpBearerAuth,
12+
HttpDigestAuth,
13+
)
14+
from ellar.exceptions import APIException
15+
from ellar.openapi import OpenAPIDocumentBuilder
16+
from ellar.serializer import serialize_object
17+
18+
19+
class CustomException(APIException):
20+
pass
21+
22+
23+
class QuerySecretKey(APIKeyQuery):
24+
async def authenticate(self, connection, key):
25+
if key == "querysecretkey":
26+
return key
27+
28+
29+
class HeaderSecretKey(APIKeyHeader):
30+
async def authenticate(self, connection, key):
31+
if key == "headersecretkey":
32+
return key
33+
34+
35+
class HeaderSecretKeyCustomException(HeaderSecretKey):
36+
exception_class = CustomException
37+
38+
39+
class CookieSecretKey(APIKeyCookie):
40+
openapi_name = "API Key Auth"
41+
42+
async def authenticate(self, connection, key):
43+
if key == "cookiesecretkey":
44+
return key
45+
46+
47+
class BasicAuth(HttpBasicAuth):
48+
openapi_name = "API Authentication"
49+
50+
async def authenticate(self, connection, credentials):
51+
if credentials.username == "admin" and credentials.password == "secret":
52+
return credentials.username
53+
54+
55+
class BearerAuth(HttpBearerAuth):
56+
openapi_name = "JWT Authentication"
57+
58+
async def authenticate(self, connection, credentials):
59+
if credentials.credentials == "bearertoken":
60+
return credentials.credentials
61+
62+
63+
class DigestAuth(HttpDigestAuth):
64+
async def authenticate(self, connection, credentials):
65+
if credentials.credentials == "digesttoken":
66+
return credentials.credentials
67+
68+
69+
app = AppFactory.create_app()
70+
71+
72+
for _path, auth in [
73+
("apikeyquery", QuerySecretKey()),
74+
("apikeyheader", HeaderSecretKey()),
75+
("apikeycookie", CookieSecretKey()),
76+
("basic", BasicAuth()),
77+
("bearer", BearerAuth()),
78+
("digest", DigestAuth()),
79+
("customexception", HeaderSecretKeyCustomException()),
80+
]:
81+
82+
@get(f"/{_path}")
83+
@guards(auth)
84+
def auth_demo_endpoint(request: Req()):
85+
return {"authentication": request.user}
86+
87+
app.router.append(auth_demo_endpoint)
88+
89+
client = TestClient(app)
90+
91+
BODY_UNAUTHORIZED_DEFAULT = {"detail": "Not authenticated"}
92+
93+
94+
@pytest.mark.parametrize(
95+
"path,kwargs,expected_code,expected_body",
96+
[
97+
("/apikeyquery", {}, HTTP_401_UNAUTHORIZED, BODY_UNAUTHORIZED_DEFAULT),
98+
(
99+
"/apikeyquery?key=querysecretkey",
100+
{},
101+
200,
102+
dict(authentication="querysecretkey"),
103+
),
104+
("/apikeyheader", {}, HTTP_401_UNAUTHORIZED, BODY_UNAUTHORIZED_DEFAULT),
105+
(
106+
"/apikeyheader",
107+
dict(headers={"key": "headersecretkey"}),
108+
200,
109+
dict(authentication="headersecretkey"),
110+
),
111+
("/apikeycookie", {}, HTTP_401_UNAUTHORIZED, BODY_UNAUTHORIZED_DEFAULT),
112+
(
113+
"/apikeycookie",
114+
dict(cookies={"key": "cookiesecretkey"}),
115+
200,
116+
dict(authentication="cookiesecretkey"),
117+
),
118+
("/basic", {}, HTTP_401_UNAUTHORIZED, BODY_UNAUTHORIZED_DEFAULT),
119+
(
120+
"/basic",
121+
dict(headers={"Authorization": "Basic YWRtaW46c2VjcmV0"}),
122+
200,
123+
dict(authentication="admin"),
124+
),
125+
(
126+
"/basic",
127+
dict(headers={"Authorization": "YWRtaW46c2VjcmV0"}),
128+
200,
129+
dict(authentication="admin"),
130+
),
131+
(
132+
"/basic",
133+
dict(headers={"Authorization": "Basic invalid"}),
134+
HTTP_401_UNAUTHORIZED,
135+
{"detail": "Invalid authentication credentials"},
136+
),
137+
(
138+
"/basic",
139+
dict(headers={"Authorization": "some invalid value"}),
140+
HTTP_401_UNAUTHORIZED,
141+
BODY_UNAUTHORIZED_DEFAULT,
142+
),
143+
("/bearer", {}, 401, BODY_UNAUTHORIZED_DEFAULT),
144+
(
145+
"/bearer",
146+
dict(headers={"Authorization": "Bearer bearertoken"}),
147+
200,
148+
dict(authentication="bearertoken"),
149+
),
150+
(
151+
"/bearer",
152+
dict(headers={"Authorization": "Invalid bearertoken"}),
153+
HTTP_401_UNAUTHORIZED,
154+
{"detail": "Invalid authentication credentials"},
155+
),
156+
("/digest", {}, 401, BODY_UNAUTHORIZED_DEFAULT),
157+
(
158+
"/digest",
159+
dict(headers={"Authorization": "Digest digesttoken"}),
160+
200,
161+
dict(authentication="digesttoken"),
162+
),
163+
(
164+
"/digest",
165+
dict(headers={"Authorization": "Invalid digesttoken"}),
166+
HTTP_401_UNAUTHORIZED,
167+
{"detail": "Invalid authentication credentials"},
168+
),
169+
("/customexception", {}, HTTP_401_UNAUTHORIZED, BODY_UNAUTHORIZED_DEFAULT),
170+
(
171+
"/customexception",
172+
dict(headers={"key": "headersecretkey"}),
173+
200,
174+
dict(authentication="headersecretkey"),
175+
),
176+
],
177+
)
178+
def test_auth(path, kwargs, expected_code, expected_body):
179+
response = client.get(path, **kwargs)
180+
assert response.status_code == expected_code
181+
assert response.json() == expected_body
182+
183+
184+
def test_auth_schema():
185+
document = serialize_object(OpenAPIDocumentBuilder().build_document(app))
186+
assert document["components"]["securitySchemes"] == {
187+
"API Key Auth": {"type": "apiKey", "in": "cookie", "name": "API Key Auth"},
188+
"HeaderSecretKey": {
189+
"type": "apiKey",
190+
"in": "header",
191+
"name": "HeaderSecretKey",
192+
},
193+
"QuerySecretKey": {"type": "apiKey", "in": "query", "name": "QuerySecretKey"},
194+
"API Authentication": {
195+
"type": "http",
196+
"scheme": "basic",
197+
"name": "API Authentication",
198+
},
199+
"JWT Authentication": {
200+
"type": "http",
201+
"scheme": "bearer",
202+
"name": "JWT Authentication",
203+
},
204+
"HeaderSecretKeyCustomException": {
205+
"type": "apiKey",
206+
"in": "header",
207+
"name": "HeaderSecretKeyCustomException",
208+
},
209+
"DigestAuth": {"type": "http", "scheme": "digest", "name": "DigestAuth"},
210+
}

0 commit comments

Comments
 (0)