1
1
import json
2
2
import os
3
+ import re
3
4
from typing import Any , Dict , List , Optional
4
5
5
6
import httpx
@@ -18,17 +19,18 @@ class SSOLoginError(HTTPException):
18
19
class SSOBase :
19
20
"""Base class (mixin) for all SSO providers"""
20
21
21
- provider : str = NotImplemented
22
- client_id : str = NotImplemented
23
- client_secret : str = NotImplemented
24
- redirect_uri : Optional [str ] = NotImplemented
25
- scope : List [str ] = NotImplemented
22
+ client_id : str = None
23
+ client_secret : str = None
24
+ redirect_uri : Optional [str ] = None
25
+ allow_insecure_http : bool = False
26
+ scope : Optional [List [str ]] = None
27
+ state : Optional [str ] = None
26
28
_oauth_client : Optional [WebApplicationClient ] = None
27
29
additional_headers : Optional [Dict [str , Any ]] = None
28
30
29
- authorization_endpoint : str = NotImplemented
30
- token_endpoint : str = NotImplemented
31
- userinfo_endpoint : str = NotImplemented
31
+ authorization_endpoint : str = None
32
+ token_endpoint : str = None
33
+ userinfo_endpoint : str = None
32
34
33
35
def __init__ (
34
36
self ,
@@ -45,12 +47,9 @@ def __init__(
45
47
if allow_insecure_http :
46
48
os .environ ["OAUTHLIB_INSECURE_TRANSPORT" ] = "1"
47
49
self .scope = scope or self .scope
48
- self .state : Optional [str ] = None
49
50
50
51
@property
51
52
def oauth_client (self ) -> WebApplicationClient :
52
- if self .client_id == NotImplemented :
53
- raise NotImplementedError (f"Provider { self .provider } not supported" )
54
53
if self ._oauth_client is None :
55
54
self ._oauth_client = WebApplicationClient (self .client_id )
56
55
return self ._oauth_client
@@ -63,10 +62,6 @@ def access_token(self) -> Optional[str]:
63
62
def refresh_token (self ) -> Optional [str ]:
64
63
return self .oauth_client .refresh_token
65
64
66
- @classmethod
67
- async def openid_from_response (cls , response : dict ) -> dict :
68
- raise NotImplementedError (f"Provider { cls .provider } not supported" )
69
-
70
65
async def get_login_url (
71
66
self ,
72
67
* ,
@@ -101,58 +96,36 @@ async def verify_and_process(
101
96
headers : Optional [Dict [str , Any ]] = None ,
102
97
redirect_uri : Optional [str ] = None ,
103
98
) -> Optional [dict ]:
104
- headers = headers or {}
105
- code = request .query_params .get ("code" )
106
- if code is None :
99
+ params = params or {}
100
+ additional_headers = headers or {}
101
+ additional_headers .update (self .additional_headers or {})
102
+ if not request .query_params .get ("code" ):
107
103
raise SSOLoginError (400 , "'code' parameter was not found in callback request" )
108
104
if self .state != request .query_params .get ("state" ):
109
105
raise SSOLoginError (400 , "'state' parameter does not match" )
110
- return await self .process_login (
111
- code , request , params = params , additional_headers = headers , redirect_uri = redirect_uri
112
- )
113
106
114
- async def process_login (
115
- self ,
116
- code : str ,
117
- request : Request ,
118
- * ,
119
- params : Optional [Dict [str , Any ]] = None ,
120
- additional_headers : Optional [Dict [str , Any ]] = None ,
121
- redirect_uri : Optional [str ] = None ,
122
- ) -> Optional [dict ]:
123
- params = params or {}
124
- additional_headers = additional_headers or {}
125
- additional_headers .update (self .additional_headers or {})
126
107
url = request .url
127
- scheme = url .scheme
128
- if not self .allow_insecure_http and scheme != "https" :
129
- current_url = str (url ).replace ("http://" , "https://" )
130
- scheme = "https"
131
- else :
132
- current_url = str (url )
108
+ scheme = "http" if self .allow_insecure_http else "https"
133
109
current_path = f"{ scheme } ://{ url .netloc } { url .path } "
110
+ current_path = re .sub (r"^https?" , scheme , current_path )
111
+ current_url = re .sub (r"^https?" , scheme , str (url ))
134
112
135
- token_url , headers , body = self .oauth_client .prepare_token_request (
113
+ token_url , headers , content = self .oauth_client .prepare_token_request (
136
114
self .token_endpoint ,
137
115
authorization_response = current_url ,
138
116
redirect_url = redirect_uri or self .redirect_uri or current_path ,
139
- code = code ,
117
+ code = request . query_params . get ( " code" ) ,
140
118
** params ,
141
119
)
142
120
143
- if token_url is None :
144
- return None
145
-
146
121
headers .update (additional_headers )
147
-
148
122
auth = httpx .BasicAuth (self .client_id , self .client_secret )
149
123
async with httpx .AsyncClient () as session :
150
- response = await session .post (token_url , headers = headers , content = body , auth = auth )
151
- content = response .json ()
152
- self .oauth_client .parse_request_body_response (json .dumps (content ))
124
+ response = await session .post (token_url , headers = headers , content = content , auth = auth )
125
+ self .oauth_client .parse_request_body_response (json .dumps (response .json ()))
153
126
154
- uri , headers , _ = self .oauth_client .add_token (self .userinfo_endpoint )
155
- response = await session .get (uri , headers = headers )
127
+ url , headers , _ = self .oauth_client .add_token (self .userinfo_endpoint )
128
+ response = await session .get (url , headers = headers )
156
129
content = response .json ()
157
130
158
131
return content
0 commit comments