1
1
import json
2
- import os
3
2
import re
4
- from typing import Any , Dict , List , Optional
3
+ from typing import Any
4
+ from typing import Dict
5
+ from typing import List
6
+ from typing import Optional
7
+ from urllib .parse import urljoin
5
8
6
9
import httpx
7
10
from oauthlib .oauth2 import WebApplicationClient
11
+ from starlette .exceptions import HTTPException
8
12
from starlette .requests import Request
9
13
from starlette .responses import RedirectResponse
10
14
11
- from .config import JWT_EXPIRES
12
- from .exceptions import OAuth2LoginError
13
- from .utils import jwt_create
15
+ from .client import OAuth2Client
16
+
17
+
18
+ class OAuth2LoginError (HTTPException ):
19
+ """Raised when any login-related error occurs
20
+ (such as when user is not verified or if there was an attempt for fake login)
21
+ """
14
22
15
23
16
24
class OAuth2Core :
@@ -19,7 +27,7 @@ class OAuth2Core:
19
27
client_id : str = None
20
28
client_secret : str = None
21
29
callback_url : Optional [str ] = None
22
- allow_insecure_http : bool = False
30
+ allow_http : bool = False
23
31
scope : Optional [List [str ]] = None
24
32
state : Optional [str ] = None
25
33
_oauth_client : Optional [WebApplicationClient ] = None
@@ -29,55 +37,47 @@ class OAuth2Core:
29
37
token_endpoint : str = None
30
38
userinfo_endpoint : str = None
31
39
32
- def __init__ (
33
- self ,
34
- client_id : str ,
35
- client_secret : str ,
36
- callback_url : Optional [str ] = None ,
37
- allow_insecure_http : bool = False ,
38
- scope : Optional [List [str ]] = None ,
39
- ):
40
- self .client_id = client_id
41
- self .client_secret = client_secret
42
- self .callback_url = callback_url
43
- self .allow_insecure_http = allow_insecure_http
44
- if allow_insecure_http :
45
- os .environ ["OAUTHLIB_INSECURE_TRANSPORT" ] = "1"
46
- self .scope = scope or self .scope
40
+ def __init__ (self , client : OAuth2Client ) -> None :
41
+ self .client_id = client .client_id
42
+ self .client_secret = client .client_secret
43
+ self .scope = client .scope or self .scope
44
+ self .provider = client .backend .name
45
+ self .authorization_endpoint = client .backend .AUTHORIZATION_URL
46
+ self .token_endpoint = client .backend .ACCESS_TOKEN_URL
47
+ self .userinfo_endpoint = "https://api.github.com/user"
48
+ self .additional_headers = {"Content-Type" : "application/x-www-form-urlencoded" , "Accept" : "application/json" }
47
49
48
50
@property
49
51
def oauth_client (self ) -> WebApplicationClient :
50
52
if self ._oauth_client is None :
51
53
self ._oauth_client = WebApplicationClient (self .client_id )
52
54
return self ._oauth_client
53
55
54
- @property
55
- def access_token (self ) -> Optional [str ]:
56
- return self .oauth_client .access_token
57
-
58
- @property
59
- def refresh_token (self ) -> Optional [str ]:
60
- return self .oauth_client .refresh_token
56
+ def get_redirect_uri (self , request : Request ) -> str :
57
+ return urljoin (str (request .base_url ), "/oauth2/%s/token" % self .provider )
61
58
62
59
async def get_login_url (
63
60
self ,
61
+ request : Request ,
64
62
* ,
65
63
params : Optional [Dict [str , Any ]] = None ,
66
64
state : Optional [str ] = None ,
67
65
) -> Any :
68
66
self .state = state
69
67
params = params or {}
68
+ redirect_uri = self .get_redirect_uri (request )
70
69
return self .oauth_client .prepare_request_uri (
71
- self .authorization_endpoint , redirect_uri = self . callback_url , state = state , scope = self .scope , ** params
70
+ self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
72
71
)
73
72
74
73
async def login_redirect (
75
74
self ,
75
+ request : Request ,
76
76
* ,
77
77
params : Optional [Dict [str , Any ]] = None ,
78
78
state : Optional [str ] = None ,
79
79
) -> RedirectResponse :
80
- login_uri = await self .get_login_url (params = params , state = state )
80
+ login_uri = await self .get_login_url (request , params = params , state = state )
81
81
return RedirectResponse (login_uri , 303 )
82
82
83
83
async def get_token_data (
@@ -96,15 +96,14 @@ async def get_token_data(
96
96
raise OAuth2LoginError (400 , "'state' parameter does not match" )
97
97
98
98
url = request .url
99
- scheme = "http" if self .allow_insecure_http else "https"
100
- current_path = f"{ scheme } ://{ url .netloc } { url .path } "
101
- current_path = re .sub (r"^https?" , scheme , current_path )
99
+ scheme = "http" if self .allow_http else "https"
102
100
current_url = re .sub (r"^https?" , scheme , str (url ))
101
+ redirect_uri = self .get_redirect_uri (request )
103
102
104
103
token_url , headers , content = self .oauth_client .prepare_token_request (
105
104
self .token_endpoint ,
105
+ redirect_url = redirect_uri ,
106
106
authorization_response = current_url ,
107
- redirect_url = self .callback_url or current_path ,
108
107
code = request .query_params .get ("code" ),
109
108
** params ,
110
109
)
@@ -129,13 +128,13 @@ async def token_redirect(
129
128
headers : Optional [Dict [str , Any ]] = None ,
130
129
) -> RedirectResponse :
131
130
token_data = await self .get_token_data (request , params = params , headers = headers )
132
- access_token = jwt_create (token_data )
131
+ access_token = request . auth . jwt_create (token_data )
133
132
response = RedirectResponse (request .base_url )
134
133
response .set_cookie (
135
134
"Authorization" ,
136
135
value = f"Bearer { access_token } " ,
137
- httponly = self .allow_insecure_http ,
138
- max_age = JWT_EXPIRES * 60 ,
139
- expires = JWT_EXPIRES * 60 ,
136
+ httponly = self .allow_http ,
137
+ max_age = request . auth . expires ,
138
+ expires = request . auth . expires ,
140
139
)
141
140
return response
0 commit comments