1
1
import abc
2
2
import asyncio
3
+ import datetime
3
4
import json
4
5
import logging
5
6
import os
15
16
16
17
from auth import auth_base
17
18
from auth .auth_base import AuthFailureError , AuthBadRequestException , AuthRejectedError
19
+ from auth .oauth_token_manager import OAuthTokenManager
20
+ from auth .oauth_token_response import OAuthTokenResponse
18
21
from model import model_helper
19
22
from model .model_helper import read_bool_from_config , read_int_from_config
20
23
from model .server_conf import InvalidServerConfigException
21
24
from utils import file_utils
22
- from utils .tornado_utils import get_secure_cookie
23
25
24
26
LOGGER = logging .getLogger ('script_server.AbstractOauthAuthenticator' )
25
27
@@ -90,6 +92,12 @@ def __init__(self, oauth_authorize_url, oauth_token_url, oauth_scope, params_dic
90
92
91
93
self ._schedule_dump_task ()
92
94
95
+ self ._token_manager = OAuthTokenManager (
96
+ enabled = bool (self .auth_info_ttl ),
97
+ fetch_token_callback = self ._fetch_token_by_refresh )
98
+
99
+ self .ioloop = tornado .ioloop .IOLoop .current ()
100
+
93
101
@staticmethod
94
102
def _validate_dump_file (dump_file ):
95
103
if os .path .isdir (dump_file ):
@@ -105,8 +113,8 @@ async def authenticate(self, request_handler):
105
113
LOGGER .error ('Code is not specified' )
106
114
raise AuthBadRequestException ('Missing authorization information. Please contact your administrator' )
107
115
108
- ( access_token , refresh_token ) = await self .fetch_access_token (code , request_handler )
109
- user_info = await self .fetch_user_info (access_token )
116
+ token_response = await self .fetch_access_token_by_code (code , request_handler )
117
+ user_info = await self .fetch_user_info (token_response . access_token )
110
118
111
119
username = user_info .username
112
120
if not username :
@@ -124,12 +132,13 @@ async def authenticate(self, request_handler):
124
132
self ._users [username ] = user_state
125
133
126
134
if self .group_support :
127
- await self .load_groups (access_token , username , user_info , user_state )
135
+ await self .load_groups (token_response . access_token , username , user_info , user_state )
128
136
129
137
now = time .time ()
130
138
139
+ self ._token_manager .update_tokens (token_response , username , request_handler )
140
+
131
141
if self .auth_info_ttl :
132
- request_handler .set_secure_cookie ('token' , access_token )
133
142
user_state .last_auth_update = now
134
143
135
144
user_state .last_visit = now
@@ -144,23 +153,28 @@ async def load_groups(self, access_token, username, user_info, user_state):
144
153
user_state .groups = user_groups
145
154
LOGGER .info ('Loaded groups for ' + username + ': ' + str (user_state .groups ))
146
155
147
- def validate_user (self , user , request_handler ):
156
+ async def validate_user (self , user , request_handler ):
148
157
if not user :
149
158
LOGGER .warning ('Username is not available' )
150
159
return False
151
160
152
161
now = time .time ()
153
162
154
163
user_state = self ._users .get (user )
164
+ validate_expiration = True
155
165
if not user_state :
156
166
# if nothing is enabled, it's ok not to have user state (e.g. after server restart)
157
167
if self .session_expire <= 0 and not self .auth_info_ttl and not self .group_support :
158
168
return True
169
+ elif self ._token_manager .can_restore_state (request_handler ):
170
+ validate_expiration = False
171
+ user_state = _UserState (user )
172
+ self ._users [user ] = user_state
159
173
else :
160
174
LOGGER .info ('User %s state is missing' , user )
161
175
return False
162
176
163
- if self .session_expire > 0 :
177
+ if ( self .session_expire > 0 ) and validate_expiration :
164
178
last_visit = user_state .last_visit
165
179
if (last_visit is None ) or ((last_visit + self .session_expire ) < now ):
166
180
LOGGER .info ('User %s state is expired' , user )
@@ -169,9 +183,10 @@ def validate_user(self, user, request_handler):
169
183
user_state .last_visit = now
170
184
171
185
if self .auth_info_ttl :
172
- access_token = get_secure_cookie ( request_handler , 'token' )
186
+ access_token = await self . _token_manager . synchronize_user_tokens ( user , request_handler )
173
187
if access_token is None :
174
188
LOGGER .info ('User %s token is not available' , user )
189
+ self ._remove_user (user )
175
190
return False
176
191
177
192
self .update_user_auth (user , user_state , access_token )
@@ -186,57 +201,40 @@ def get_groups(self, user, known_groups=None):
186
201
return user_state .groups
187
202
188
203
def logout (self , user , request_handler ):
189
- request_handler . clear_cookie ( 'token' )
204
+ self . _token_manager . logout ( user , request_handler )
190
205
self ._remove_user (user )
191
206
192
207
self ._dump_state ()
193
208
194
209
def _remove_user (self , user ):
195
210
if user in self ._users :
196
211
del self ._users [user ]
212
+ self ._token_manager .remove_user (user )
197
213
198
- async def fetch_access_token (self , code , request_handler ):
199
- body = urllib_parse . urlencode ({
214
+ async def fetch_access_token_by_code (self , code , request_handler ):
215
+ return await self . _fetch_token ({
200
216
'redirect_uri' : get_path_for_redirect (request_handler ),
201
217
'code' : code ,
202
218
'client_id' : self .client_id ,
203
219
'client_secret' : self .secret ,
204
220
'grant_type' : 'authorization_code' ,
205
221
})
206
222
207
- response = await self .http_client .fetch (
208
- self .oauth_token_url ,
209
- method = 'POST' ,
210
- headers = {'Content-Type' : 'application/x-www-form-urlencoded' },
211
- body = body ,
212
- raise_error = False )
213
-
214
- response_values = {}
215
- if response .body :
216
- response_values = escape .json_decode (response .body )
217
-
218
- if response .error :
219
- if response_values .get ('error_description' ):
220
- error_text = response_values .get ('error_description' )
221
- elif response_values .get ('error' ):
222
- error_text = response_values .get ('error' )
223
- else :
224
- error_text = str (response .error )
225
-
226
- error_message = 'Failed to load access_token: ' + error_text
227
- LOGGER .error (error_message )
228
- raise AuthFailureError (error_message )
229
-
230
- response_values = escape .json_decode (response .body )
231
- access_token = response_values .get ('access_token' )
232
- refresh_token = response_values .get ('refresh_token' )
233
-
234
- if not access_token :
235
- message = 'No access token in response: ' + str (response .body )
236
- LOGGER .error (message )
237
- raise AuthFailureError (message )
238
-
239
- return access_token , refresh_token
223
+ async def _fetch_token_by_refresh (self , refresh_token , username ):
224
+ if username not in self ._users :
225
+ return None
226
+
227
+ try :
228
+ return await self ._fetch_token ({
229
+ 'refresh_token' : refresh_token ,
230
+ 'client_id' : self .client_id ,
231
+ 'client_secret' : self .secret ,
232
+ 'grant_type' : 'refresh_token' ,
233
+ })
234
+ except AuthFailureError :
235
+ LOGGER .info (f'Failed to refresh token for user { username } . Logging out' )
236
+ self ._remove_user (username )
237
+ return None
240
238
241
239
def update_user_auth (self , username , user_state , access_token ):
242
240
now = time .time ()
@@ -246,7 +244,7 @@ def update_user_auth(self, username, user_state, access_token):
246
244
if not ttl_expired :
247
245
return
248
246
249
- tornado .ioloop . IOLoop . current () .spawn_callback (
247
+ self .ioloop .spawn_callback (
250
248
self ._do_update_user_auth_async ,
251
249
username ,
252
250
user_state ,
@@ -342,6 +340,41 @@ def _cleanup(self):
342
340
if self .timer :
343
341
self .timer .cancel ()
344
342
343
+ async def _fetch_token (self , body ):
344
+ encoded_body = urllib_parse .urlencode (body )
345
+
346
+ response = await self .http_client .fetch (
347
+ self .oauth_token_url ,
348
+ method = 'POST' ,
349
+ headers = {'Content-Type' : 'application/x-www-form-urlencoded' },
350
+ body = encoded_body ,
351
+ raise_error = False )
352
+
353
+ response_values = {}
354
+ if response .body :
355
+ response_values = escape .json_decode (response .body )
356
+
357
+ if response .error :
358
+ if response_values .get ('error_description' ):
359
+ error_text = response_values .get ('error_description' )
360
+ elif response_values .get ('error' ):
361
+ error_text = response_values .get ('error' )
362
+ else :
363
+ error_text = str (response .error )
364
+
365
+ error_message = 'Failed to refresh access_token: ' + error_text
366
+ LOGGER .error (error_message )
367
+ raise AuthFailureError (error_message )
368
+
369
+ token_response = OAuthTokenResponse .create (response_values , datetime .datetime .now ())
370
+
371
+ if not token_response .access_token :
372
+ message = 'No access token in response: ' + str (response .body )
373
+ LOGGER .error (message )
374
+ raise AuthFailureError (message )
375
+
376
+ return token_response
377
+
345
378
346
379
def get_path_for_redirect (request_handler ):
347
380
referer = request_handler .request .headers .get ('Referer' )
0 commit comments