6
6
import threading
7
7
import time
8
8
import urllib .parse as urllib_parse
9
- from collections import namedtuple , defaultdict
9
+ from collections import defaultdict
10
10
from typing import Dict
11
11
12
12
import tornado
13
13
import tornado .ioloop
14
14
from tornado import httpclient , escape
15
15
16
16
from auth import auth_base
17
- from auth .auth_base import AuthFailureError , AuthBadRequestException
17
+ from auth .auth_base import AuthFailureError , AuthBadRequestException , AuthRejectedError
18
18
from model import model_helper
19
19
from model .model_helper import read_bool_from_config , read_int_from_config
20
20
from model .server_conf import InvalidServerConfigException
21
21
from utils import file_utils
22
+ from utils .tornado_utils import get_secure_cookie
22
23
23
24
LOGGER = logging .getLogger ('script_server.AbstractOauthAuthenticator' )
24
25
@@ -31,7 +32,21 @@ def __init__(self, username) -> None:
31
32
self .last_visit = None
32
33
33
34
34
- _OauthUserInfo = namedtuple ('_OauthUserInfo' , ['email' , 'enabled' , 'oauth_response' ])
35
+ class _OauthUserInfo :
36
+ def __init__ (self , username , enabled , oauth_response , eager_groups = None ):
37
+ self .username = username
38
+ self .enabled = enabled
39
+ self .oauth_response = oauth_response
40
+ self .eager_groups = eager_groups
41
+
42
+ def __eq__ (self , o : object ) -> bool :
43
+ return isinstance (o , _OauthUserInfo ) and (self .username == o .username )
44
+
45
+ def __str__ (self ) -> str :
46
+ return f'_OauthUserInfo({ self .username } )'
47
+
48
+ def __repr__ (self ) -> str :
49
+ return f'_OauthUserInfo({ self .__dict__ } )'
35
50
36
51
37
52
def _start_timer (callback ):
@@ -67,6 +82,8 @@ def __init__(self, oauth_authorize_url, oauth_token_url, oauth_scope, params_dic
67
82
self ._users = {} # type: Dict[str, _UserState]
68
83
self ._user_locks = defaultdict (lambda : asyncio .locks .Lock ())
69
84
85
+ self .http_client = httpclient .AsyncHTTPClient ()
86
+
70
87
self .timer = None
71
88
if self .dump_file :
72
89
self ._restore_state ()
@@ -88,28 +105,26 @@ async def authenticate(self, request_handler):
88
105
LOGGER .error ('Code is not specified' )
89
106
raise AuthBadRequestException ('Missing authorization information. Please contact your administrator' )
90
107
91
- access_token = await self .fetch_access_token (code , request_handler )
108
+ ( access_token , refresh_token ) = await self .fetch_access_token (code , request_handler )
92
109
user_info = await self .fetch_user_info (access_token )
93
110
94
- user_email = user_info .email
95
- if not user_email :
111
+ username = user_info .username
112
+ if not username :
96
113
error_message = 'No email field in user response. The response: ' + str (user_info .oauth_response )
97
114
LOGGER .error (error_message )
98
115
raise AuthFailureError (error_message )
99
116
100
117
if not user_info .enabled :
101
118
error_message = 'User %s is not enabled in OAuth provider. The response: %s' \
102
- % (user_email , str (user_info .oauth_response ))
119
+ % (username , str (user_info .oauth_response ))
103
120
LOGGER .error (error_message )
104
121
raise AuthFailureError (error_message )
105
122
106
- user_state = _UserState (user_email )
107
- self ._users [user_email ] = user_state
123
+ user_state = _UserState (username )
124
+ self ._users [username ] = user_state
108
125
109
126
if self .group_support :
110
- user_groups = await self .fetch_user_groups (access_token )
111
- LOGGER .info ('Loaded groups for ' + user_email + ': ' + str (user_groups ))
112
- user_state .groups = user_groups
127
+ await self .load_groups (access_token , username , user_info , user_state )
113
128
114
129
now = time .time ()
115
130
@@ -119,7 +134,15 @@ async def authenticate(self, request_handler):
119
134
120
135
user_state .last_visit = now
121
136
122
- return user_email
137
+ return username
138
+
139
+ async def load_groups (self , access_token , username , user_info , user_state ):
140
+ if user_info .eager_groups is not None :
141
+ user_state .groups = user_info .eager_groups
142
+ else :
143
+ user_groups = await self .fetch_user_groups (access_token )
144
+ user_state .groups = user_groups
145
+ LOGGER .info ('Loaded groups for ' + username + ': ' + str (user_state .groups ))
123
146
124
147
def validate_user (self , user , request_handler ):
125
148
if not user :
@@ -146,7 +169,7 @@ def validate_user(self, user, request_handler):
146
169
user_state .last_visit = now
147
170
148
171
if self .auth_info_ttl :
149
- access_token = request_handler . get_secure_cookie ('token' )
172
+ access_token = get_secure_cookie (request_handler , 'token' )
150
173
if access_token is None :
151
174
LOGGER .info ('User %s token is not available' , user )
152
175
return False
@@ -180,8 +203,8 @@ async def fetch_access_token(self, code, request_handler):
180
203
'client_secret' : self .secret ,
181
204
'grant_type' : 'authorization_code' ,
182
205
})
183
- http_client = httpclient . AsyncHTTPClient ()
184
- response = await http_client .fetch (
206
+
207
+ response = await self . http_client .fetch (
185
208
self .oauth_token_url ,
186
209
method = 'POST' ,
187
210
headers = {'Content-Type' : 'application/x-www-form-urlencoded' },
@@ -206,13 +229,14 @@ async def fetch_access_token(self, code, request_handler):
206
229
207
230
response_values = escape .json_decode (response .body )
208
231
access_token = response_values .get ('access_token' )
232
+ refresh_token = response_values .get ('refresh_token' )
209
233
210
234
if not access_token :
211
235
message = 'No access token in response: ' + str (response .body )
212
236
LOGGER .error (message )
213
237
raise AuthFailureError (message )
214
238
215
- return access_token
239
+ return access_token , refresh_token
216
240
217
241
def update_user_auth (self , username , user_state , access_token ):
218
242
now = time .time ()
@@ -242,8 +266,14 @@ async def _do_update_user_auth_async(self, username, user_state, access_token):
242
266
243
267
LOGGER .info ('User %s state expired, refreshing' , username )
244
268
245
- user_info = await self .fetch_user_info (access_token ) # type: _OauthUserInfo
246
- if (not user_info ) or (not user_info .email ):
269
+ try :
270
+ user_info = await self .fetch_user_info (access_token ) # type: _OauthUserInfo
271
+ except AuthRejectedError :
272
+ LOGGER .info (f'User { username } is not authenticated anymore. Logging out' )
273
+ self ._remove_user (username )
274
+ return
275
+
276
+ if (not user_info ) or (not user_info .username ):
247
277
LOGGER .error ('Failed to fetch user info: %s' , str (user_info ))
248
278
self ._remove_user (username )
249
279
return
@@ -256,9 +286,7 @@ async def _do_update_user_auth_async(self, username, user_state, access_token):
256
286
257
287
if self .group_support :
258
288
try :
259
- user_groups = await self .fetch_user_groups (access_token )
260
- LOGGER .info ('Updated groups for ' + username + ': ' + str (user_groups ))
261
- user_state .groups = user_groups
289
+ await self .load_groups (access_token , username , user_info , user_state )
262
290
except AuthFailureError :
263
291
LOGGER .error ('Failed to fetch user %s groups' , username )
264
292
self ._remove_user (username )
0 commit comments