1
1
import json
2
2
import os
3
- import sys
4
- import warnings
5
3
from typing import Any , Dict , List , Optional
6
4
7
5
import httpx
10
8
from starlette .requests import Request
11
9
from starlette .responses import RedirectResponse
12
10
13
- if sys .version_info >= (3 , 8 ):
14
- from typing import TypedDict
15
- else :
16
- from typing_extensions import TypedDict
17
-
18
- DiscoveryDocument = TypedDict (
19
- "DiscoveryDocument" , {"authorization_endpoint" : str , "token_endpoint" : str , "userinfo_endpoint" : str }
20
- )
21
-
22
11
23
12
class UnsetStateWarning (UserWarning ):
24
13
"""Warning about unset state parameter"""
@@ -41,13 +30,16 @@ class SSOBase:
41
30
_oauth_client : Optional [WebApplicationClient ] = None
42
31
additional_headers : Optional [Dict [str , Any ]] = None
43
32
33
+ authorization_endpoint : str = NotImplemented
34
+ token_endpoint : str = NotImplemented
35
+ userinfo_endpoint : str = NotImplemented
36
+
44
37
def __init__ (
45
38
self ,
46
39
client_id : str ,
47
40
client_secret : str ,
48
41
redirect_uri : Optional [str ] = None ,
49
42
allow_insecure_http : bool = False ,
50
- use_state : bool = False ,
51
43
scope : Optional [List [str ]] = None ,
52
44
):
53
45
self .client_id = client_id
@@ -56,33 +48,11 @@ def __init__(
56
48
self .allow_insecure_http = allow_insecure_http
57
49
if allow_insecure_http :
58
50
os .environ ["OAUTHLIB_INSECURE_TRANSPORT" ] = "1"
59
- # TODO: Remove use_state argument and attribute
60
- if use_state :
61
- warnings .warn (
62
- (
63
- "Argument 'use_state' of SSOBase's constructor is deprecated and will be removed in "
64
- "future releases. Use 'state' argument of individual methods instead."
65
- ),
66
- DeprecationWarning ,
67
- )
68
51
self .scope = scope or self .scope
69
- self ._refresh_token : Optional [str ] = None
70
- self ._state : Optional [str ] = None
71
-
72
- @property
73
- def state (self ) -> Optional [str ]:
74
- """Gets state as it was returned from the server"""
75
- if self ._state is None :
76
- warnings .warn (
77
- "'state' parameter is unset. This means the server either "
78
- "didn't return state (was this expected?) or 'verify_and_process' hasn't been called yet." ,
79
- UnsetStateWarning ,
80
- )
81
- return self ._state
52
+ self .state : Optional [str ] = None
82
53
83
54
@property
84
55
def oauth_client (self ) -> WebApplicationClient :
85
- """OAuth Client to help us generate requests and parse responses"""
86
56
if self .client_id == NotImplemented :
87
57
raise NotImplementedError (f"Provider { self .provider } not supported" )
88
58
if self ._oauth_client is None :
@@ -91,55 +61,29 @@ def oauth_client(self) -> WebApplicationClient:
91
61
92
62
@property
93
63
def access_token (self ) -> Optional [str ]:
94
- """Access token from token endpoint"""
95
64
return self .oauth_client .access_token
96
65
97
66
@property
98
67
def refresh_token (self ) -> Optional [str ]:
99
- """Get refresh token (if returned from provider)"""
100
- return self ._refresh_token or self .oauth_client .refresh_token
68
+ return self .oauth_client .refresh_token
101
69
102
70
@classmethod
103
71
async def openid_from_response (cls , response : dict ) -> dict :
104
- """Return {dict} object from provider's user info endpoint response"""
105
72
raise NotImplementedError (f"Provider { cls .provider } not supported" )
106
73
107
- async def get_discovery_document (self ) -> DiscoveryDocument :
108
- """Get discovery document containing handy urls"""
109
- raise NotImplementedError (f"Provider { self .provider } not supported" )
110
-
111
- @property
112
- async def authorization_endpoint (self ) -> Optional [str ]:
113
- """Return `authorization_endpoint` from discovery document"""
114
- discovery = await self .get_discovery_document ()
115
- return discovery .get ("authorization_endpoint" )
116
-
117
- @property
118
- async def token_endpoint (self ) -> Optional [str ]:
119
- """Return `token_endpoint` from discovery document"""
120
- discovery = await self .get_discovery_document ()
121
- return discovery .get ("token_endpoint" )
122
-
123
- @property
124
- async def userinfo_endpoint (self ) -> Optional [str ]:
125
- """Return `userinfo_endpoint` from discovery document"""
126
- discovery = await self .get_discovery_document ()
127
- return discovery .get ("userinfo_endpoint" )
128
-
129
74
async def get_login_url (
130
75
self ,
131
76
* ,
132
77
redirect_uri : Optional [str ] = None ,
133
78
params : Optional [Dict [str , Any ]] = None ,
134
79
state : Optional [str ] = None ,
135
80
) -> Any :
136
- """Return prepared login url. This is low-level, see {get_login_redirect} instead."""
137
81
params = params or {}
138
82
redirect_uri = redirect_uri or self .redirect_uri
139
83
if redirect_uri is None :
140
84
raise ValueError ("redirect_uri must be provided, either at construction or request time" )
141
85
return self .oauth_client .prepare_request_uri (
142
- await self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
86
+ self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
143
87
)
144
88
145
89
async def get_login_redirect (
@@ -149,20 +93,8 @@ async def get_login_redirect(
149
93
params : Optional [Dict [str , Any ]] = None ,
150
94
state : Optional [str ] = None ,
151
95
) -> RedirectResponse :
152
- """Return redirect response by Starlette to login page of Oauth SSO provider
153
-
154
- Arguments:
155
- redirect_uri {Optional[str]} -- Override redirect_uri specified on this instance (default: None)
156
- params {Optional[Dict[str, Any]]} -- Add additional query parameters to the login request.
157
- state {Optional[str]} -- Add state parameter. This is useful if you want
158
- the server to return something specific back to you.
159
-
160
- Returns:
161
- RedirectResponse -- Starlette response (may directly be returned from FastAPI)
162
- """
163
96
login_uri = await self .get_login_url (redirect_uri = redirect_uri , params = params , state = state )
164
- response = RedirectResponse (login_uri , 303 )
165
- return response
97
+ return RedirectResponse (login_uri , 303 )
166
98
167
99
async def verify_and_process (
168
100
self ,
@@ -172,21 +104,11 @@ async def verify_and_process(
172
104
headers : Optional [Dict [str , Any ]] = None ,
173
105
redirect_uri : Optional [str ] = None ,
174
106
) -> Optional [dict ]:
175
- """Get FastAPI (Starlette) Request object and process login.
176
- This handler should be used for your /callback path.
177
-
178
- Arguments:
179
- request {Request} -- FastAPI request object (or Starlette)
180
- params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
181
-
182
- Returns:
183
- Optional[dict] -- dict if the login was successfully
184
- """
185
107
headers = headers or {}
186
108
code = request .query_params .get ("code" )
187
109
if code is None :
188
110
raise SSOLoginError (400 , "'code' parameter was not found in callback request" )
189
- self ._state = request .query_params .get ("state" )
111
+ self .state = request .query_params .get ("state" )
190
112
return await self .process_login (
191
113
code , request , params = params , additional_headers = headers , redirect_uri = redirect_uri
192
114
)
@@ -200,13 +122,6 @@ async def process_login(
200
122
additional_headers : Optional [Dict [str , Any ]] = None ,
201
123
redirect_uri : Optional [str ] = None ,
202
124
) -> Optional [dict ]:
203
- """This method should be called from callback endpoint to verify the user and request user info endpoint.
204
- This is low level, you should use {verify_and_process} instead.
205
-
206
- Arguments:
207
- params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
208
- additional_headers {Optional[Dict[str, Any]]} -- Optional additional headers to be added to all requests
209
- """
210
125
params = params or {}
211
126
additional_headers = additional_headers or {}
212
127
additional_headers .update (self .additional_headers or {})
@@ -220,7 +135,7 @@ async def process_login(
220
135
current_path = f"{ scheme } ://{ url .netloc } { url .path } "
221
136
222
137
token_url , headers , body = self .oauth_client .prepare_token_request (
223
- await self .token_endpoint ,
138
+ self .token_endpoint ,
224
139
authorization_response = current_url ,
225
140
redirect_url = redirect_uri or self .redirect_uri or current_path ,
226
141
code = code ,
@@ -236,11 +151,10 @@ async def process_login(
236
151
async with httpx .AsyncClient () as session :
237
152
response = await session .post (token_url , headers = headers , content = body , auth = auth )
238
153
content = response .json ()
239
- self ._refresh_token = content .get ("refresh_token" )
240
154
self .oauth_client .parse_request_body_response (json .dumps (content ))
241
155
242
- uri , headers , _ = self .oauth_client .add_token (await self .userinfo_endpoint )
156
+ uri , headers , _ = self .oauth_client .add_token (self .userinfo_endpoint )
243
157
response = await session .get (uri , headers = headers )
244
158
content = response .json ()
245
159
246
- return await self . openid_from_response ( content )
160
+ return content
0 commit comments