1
- """ Token class is a front-end to the TokenDB Database.
2
-
3
- Long-term user tokens are stored here, which can be used to obtain new tokens.
1
+ """ Auth class is a front-end to the Auth Database
4
2
"""
5
3
from __future__ import absolute_import
6
4
from __future__ import division
28
26
29
27
30
28
class Token (Model , OAuth2TokenMixin ):
31
- """This class describe token fields"""
32
-
33
29
__tablename__ = "Token"
34
30
__table_args__ = {"mysql_engine" : "InnoDB" , "mysql_charset" : "utf8" }
35
31
# access_token too large for varchar(255)
36
32
# 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6
37
33
# https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length
38
- id = Column (Integer , autoincrement = True , primary_key = True ) # Unique token ID
39
- kid = Column (String (255 )) # Unique secret key ID for token encryption
40
- user_id = Column (String (255 )) # User identificator that registred in an identity provider, token owner
41
- provider = Column (String (255 )) # Provider name registred in DIRAC
42
- expires_at = Column (Integer , nullable = False , default = 0 ) # When the access token is expired
34
+ id = Column (Integer , autoincrement = True , primary_key = True )
35
+ kid = Column (String (255 ))
36
+ user_id = Column (String (255 ))
37
+ provider = Column (String (255 ))
38
+ expires_at = Column (Integer , nullable = False , default = 0 )
43
39
access_token = Column (Text , nullable = False )
44
40
refresh_token = Column (Text , nullable = False )
45
- rt_expires_at = Column (Integer , nullable = False , default = 0 ) # When the refresh token is expired
41
+ rt_expires_at = Column (Integer , nullable = False , default = 0 )
46
42
47
43
48
44
class TokenDB (SQLAlchemyDB ):
@@ -58,10 +54,7 @@ def __init__(self):
58
54
self .session = scoped_session (self .sessionMaker_o )
59
55
60
56
def __initializeDB (self ):
61
- """Create the tables
62
-
63
- :return: S_OK()/S_ERROR()
64
- """
57
+ """Create the tables"""
65
58
tablesInDB = self .inspector .get_table_names ()
66
59
67
60
# Token
@@ -79,7 +72,7 @@ def getTokenForUserProvider(self, userID, provider):
79
72
:param str userID: user ID
80
73
:param str provider: provider
81
74
82
- :return: S_OK(OAuth2Token )/S_ERROR() -- return an OAuth2Token object, which is also a dict
75
+ :return: S_OK(dict )/S_ERROR()
83
76
"""
84
77
session = self .session ()
85
78
try :
@@ -95,40 +88,34 @@ def getTokenForUserProvider(self, userID, provider):
95
88
return self .__result (session , S_OK (OAuth2Token (self .__rowToDict (token )) if token else None ))
96
89
97
90
def updateToken (self , token , userID , provider , rt_expired_in ):
98
- """Update tokens for user and identity provider
91
+ """Update tokens
99
92
100
93
:param dict token: token info
101
- :param str userID: user ID that comes from identity provider
102
- :param str provider: provider name
94
+ :param str userID: user ID
95
+ :param str provider: provider
103
96
:param int rt_expired_in: refresh token lifetime
104
97
105
- :return: S_OK(list)/S_ERROR() -- return old tokens that should be revoked.
98
+ :return: S_OK(list)/S_ERROR()
106
99
"""
107
- # Prepare a token to write to the database
108
100
token ["user_id" ] = userID
109
101
token ["provider" ] = provider
110
- # If the token expiration date is not specified, we will try to determine it
111
102
if not token .get ("rt_expires_at" ):
112
103
try :
113
- # This value can be contained in the token itself if it is a JWT
114
104
token ["rt_expires_at" ] = int (
115
105
jwt .decode (token ["refresh_token" ], options = dict (verify_signature = False , verify_aud = False ))["exp" ]
116
106
)
117
107
except Exception as e :
118
108
self .log .debug ("Cannot get refresh token expires time: %s" % repr (e ))
119
- # Otherwise, we set this value
109
+
120
110
token ["rt_expires_at" ] = int (token .get ("rt_expires_at" , rt_expired_in + int (time .time ())))
121
- # We ignore expired tokens
122
111
if token ["rt_expires_at" ] < time .time ():
123
112
return S_ERROR ("Cannot store expired refresh token." )
124
113
125
114
attrts = dict ((k , v ) for k , v in dict (token ).items () if k in list (Token .__dict__ .keys ()))
126
115
self .log .debug ("Store token:" , pprint .pformat (attrts ))
127
116
session = self .session ()
128
117
try :
129
- # Remove expired tokens
130
118
session .query (Token ).filter (Token .expires_at < time .time ()).delete ()
131
- # When we update existing tokens, the old tokens should be revoked
132
119
oldTokens = session .query (Token ).filter (Token .user_id == userID ).filter (Token .provider == provider ).all ()
133
120
session .add (Token (** attrts ))
134
121
session .query (Token ).filter (Token .user_id == userID ).filter (Token .provider == provider ).filter (
@@ -141,12 +128,12 @@ def updateToken(self, token, userID, provider, rt_expired_in):
141
128
return self .__result (session , S_OK ([self .__rowToDict (t ) for t in oldTokens ] if oldTokens else []))
142
129
143
130
def removeToken (self , access_token = None , refresh_token = None , user_id = None ):
144
- """Remove token from DB
131
+ """Remove token
145
132
146
133
:param str access_token: access token
147
134
:param str refresh_token: refresh token
148
135
149
- :return: S_OK(str )/S_ERROR()
136
+ :return: S_OK(object )/S_ERROR()
150
137
"""
151
138
session = self .session ()
152
139
try :
@@ -161,12 +148,6 @@ def removeToken(self, access_token=None, refresh_token=None, user_id=None):
161
148
return self .__result (session , S_OK ("Token successfully removed" ))
162
149
163
150
def getTokensByUserID (self , userID ):
164
- """Return tokens for user ID
165
-
166
- :param str userID: user ID that return identity provider
167
-
168
- :return: S_OK(list)/S_ERROR() -- tokens as OAuth2Token objects
169
- """
170
151
session = self .session ()
171
152
try :
172
153
tokens = session .query (Token ).filter (Token .user_id == userID ).all ()
@@ -177,13 +158,6 @@ def getTokensByUserID(self, userID):
177
158
return self .__result (session , S_OK ([OAuth2Token (self .__rowToDict (t )) for t in tokens ]))
178
159
179
160
def __result (self , session , result = None ):
180
- """Helper method
181
-
182
- :param session: session instance
183
- :param result: DIRAC result
184
-
185
- :return: S_OK()/S_ERROR()
186
- """
187
161
try :
188
162
if not result ["OK" ]:
189
163
session .rollback ()
0 commit comments