Skip to content

Commit 0ffac70

Browse files
committed
GH-18: Add tests for the PKCE workflow
1 parent d9a2393 commit 0ffac70

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def auth(request: Request):
8383
async def token(request: Request, provider: str):
8484
if request.auth.ssr:
8585
return await request.auth.clients[provider].token_redirect(request, app=get_idp())
86-
return await request.auth.clients[provider].token_data(request)
86+
return await request.auth.clients[provider].token_data(request, app=get_idp())
8787

8888
application.include_router(app_router)
8989
application.include_router(oauth2_router)

tests/test_oauth2.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,45 @@
1+
from urllib.parse import urlencode
2+
13
import pytest
24
from httpx import AsyncClient
5+
from jose.jwt import encode as jwt_encode
6+
from oauthlib.oauth2 import WebApplicationClient
37

48

5-
@pytest.mark.anyio
6-
async def test_oauth2_basic_flow(get_app):
7-
async with AsyncClient(app=get_app(), base_url="http://test") as client:
8-
response = await client.get("/user")
9-
assert response.status_code == 403
10-
response = await client.get("/oauth2/test/auth")
11-
response = await client.get(response.headers.get("location"))
12-
await client.get(response.headers.get("location"))
9+
async def oauth2_basic_workflow(get_app, idp=False, ssr=True, authorize_query="", token_query="", use_header=False):
10+
async with AsyncClient(app=get_app(with_idp=idp, with_ssr=ssr), base_url="http://test") as client:
1311
response = await client.get("/user")
14-
assert response.status_code == 200
12+
assert response.status_code == 403 # Forbidden
13+
14+
response = await client.get("/oauth2/test/authorize" + authorize_query) # Get authorization endpoint
15+
authorization_endpoint = response.headers.get("location") if ssr else response.json().get("url")
16+
response = await client.get(authorization_endpoint) # Authorize
17+
response = await client.get(response.headers.get("location") + token_query) # Obtain token
18+
19+
response = await client.get("/user", headers=dict(
20+
Authorization=jwt_encode(response.json(), "") # Set token
21+
) if use_header else None)
22+
assert response.status_code == 200 # OK
23+
24+
25+
@pytest.mark.anyio
26+
async def test_oauth2_basic_workflow(get_app):
27+
await oauth2_basic_workflow(get_app, idp=True)
28+
await oauth2_basic_workflow(get_app, idp=True, ssr=False, use_header=True)
29+
30+
31+
@pytest.mark.anyio
32+
async def test_oauth2_pkce_workflow(get_app):
33+
for code_challenge_method in (None, "S256"):
34+
# Generate the code verifier and challenge
35+
oauth_client = WebApplicationClient("test_id")
36+
code_verifier = oauth_client.create_code_verifier(128)
37+
code_challenge = oauth_client.create_code_challenge(code_verifier, code_challenge_method)
38+
39+
aq = dict(code_challenge=code_challenge)
40+
if code_challenge_method:
41+
aq["code_challenge_method"] = code_challenge_method
42+
aq = "?" + urlencode(aq)
43+
tq = "&" + urlencode(dict(code_verifier=code_verifier))
44+
await oauth2_basic_workflow(get_app, idp=True, authorize_query=aq, token_query=tq)
45+
await oauth2_basic_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True)

0 commit comments

Comments
 (0)