1
+ import importlib
2
+ import os
3
+
1
4
import pytest
5
+ import social_core .backends as backends
6
+ from fastapi import APIRouter
2
7
from fastapi import FastAPI
3
8
from httpx import AsyncClient
4
9
from social_core .backends .github import GithubOAuth2
10
+ from social_core .backends .oauth import BaseOAuth2
11
+ from starlette .requests import Request
5
12
6
13
from fastapi_oauth2 .client import OAuth2Client
14
+ from fastapi_oauth2 .config import OAuth2Config
15
+ from fastapi_oauth2 .middleware import OAuth2Backend
7
16
from fastapi_oauth2 .middleware import OAuth2Middleware
8
17
from fastapi_oauth2 .router import router as oauth2_router
9
18
10
19
app = FastAPI ()
20
+ router = APIRouter ()
21
+
22
+
23
+ @router .get ("/test_backends" )
24
+ async def _backends (request : Request ):
25
+ responses = []
26
+ for module in os .listdir (backends .__path__ [0 ]):
27
+ try :
28
+ module_instance = importlib .import_module ("social_core.backends.%s" % module [:- 3 ])
29
+ backend_implementations = [
30
+ attr for attr in module_instance .__dict__ .values ()
31
+ if type (attr ) is type and all ([
32
+ issubclass (attr , BaseOAuth2 ),
33
+ attr is not BaseOAuth2 ,
34
+ ])
35
+ ]
36
+ for backend_cls in backend_implementations :
37
+ backend = OAuth2Backend (OAuth2Config (
38
+ clients = [
39
+ OAuth2Client (
40
+ backend = backend_cls ,
41
+ client_id = "test_client_id" ,
42
+ client_secret = "test_client_secret" ,
43
+ )
44
+ ]
45
+ ))
46
+ responses .append (await backend .authenticate (request ))
47
+ except ImportError :
48
+ continue
49
+ return responses
50
+
11
51
52
+ app .include_router (router )
12
53
app .include_router (oauth2_router )
13
54
app .add_middleware (OAuth2Middleware , config = {
14
55
"allow_http" : True ,
@@ -27,3 +68,9 @@ async def test_auth_redirect():
27
68
async with AsyncClient (app = app , base_url = "http://test" ) as client :
28
69
response = await client .get ("/oauth2/github/auth" )
29
70
assert response .status_code == 303 # Redirect
71
+
72
+
73
+ @pytest .mark .anyio
74
+ async def test_backends ():
75
+ async with AsyncClient (app = app , base_url = "http://test" ) as client :
76
+ assert all ((await client .get ("/test_backends" )).json ())
0 commit comments