|
| 1 | +from urllib.parse import urlencode |
| 2 | + |
1 | 3 | import pytest
|
2 | 4 | from httpx import AsyncClient
|
| 5 | +from jose.jwt import encode as jwt_encode |
| 6 | +from oauthlib.oauth2 import WebApplicationClient |
3 | 7 |
|
4 | 8 |
|
5 |
| -@pytest.mark.anyio |
6 |
| -async def test_oauth2_basic_flow(get_app): |
7 |
| - async with AsyncClient(app=get_app(), base_url="http://test") as client: |
8 |
| - response = await client.get("/user") |
9 |
| - assert response.status_code == 403 |
10 |
| - response = await client.get("/oauth2/test/auth") |
11 |
| - response = await client.get(response.headers.get("location")) |
12 |
| - await client.get(response.headers.get("location")) |
| 9 | +async def oauth2_basic_workflow(get_app, idp=False, ssr=True, authorize_query="", token_query="", use_header=False): |
| 10 | + async with AsyncClient(app=get_app(with_idp=idp, with_ssr=ssr), base_url="http://test") as client: |
13 | 11 | response = await client.get("/user")
|
14 |
| - assert response.status_code == 200 |
| 12 | + assert response.status_code == 403 # Forbidden |
| 13 | + |
| 14 | + response = await client.get("/oauth2/test/authorize" + authorize_query) # Get authorization endpoint |
| 15 | + authorization_endpoint = response.headers.get("location") if ssr else response.json().get("url") |
| 16 | + response = await client.get(authorization_endpoint) # Authorize |
| 17 | + response = await client.get(response.headers.get("location") + token_query) # Obtain token |
| 18 | + |
| 19 | + response = await client.get("/user", headers=dict( |
| 20 | + Authorization=jwt_encode(response.json(), "") # Set token |
| 21 | + ) if use_header else None) |
| 22 | + assert response.status_code == 200 # OK |
| 23 | + |
| 24 | + |
| 25 | +@pytest.mark.anyio |
| 26 | +async def test_oauth2_basic_workflow(get_app): |
| 27 | + await oauth2_basic_workflow(get_app, idp=True) |
| 28 | + await oauth2_basic_workflow(get_app, idp=True, ssr=False, use_header=True) |
| 29 | + |
| 30 | + |
| 31 | +@pytest.mark.anyio |
| 32 | +async def test_oauth2_pkce_workflow(get_app): |
| 33 | + for code_challenge_method in (None, "S256"): |
| 34 | + # Generate the code verifier and challenge |
| 35 | + oauth_client = WebApplicationClient("test_id") |
| 36 | + code_verifier = oauth_client.create_code_verifier(128) |
| 37 | + code_challenge = oauth_client.create_code_challenge(code_verifier, code_challenge_method) |
| 38 | + |
| 39 | + aq = dict(code_challenge=code_challenge) |
| 40 | + if code_challenge_method: |
| 41 | + aq["code_challenge_method"] = code_challenge_method |
| 42 | + aq = "?" + urlencode(aq) |
| 43 | + tq = "&" + urlencode(dict(code_verifier=code_verifier)) |
| 44 | + await oauth2_basic_workflow(get_app, idp=True, authorize_query=aq, token_query=tq) |
| 45 | + await oauth2_basic_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True) |
0 commit comments