2424# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2525# THE SOFTWARE.
2626#
27-
27+ from __future__ import annotations
2828from . import queryMessage_pb2
2929import sys
30+ import traceback
3031import os
3132import socket
3233import struct
3536import ssl
3637import logging
3738
39+ from threading import Lock
40+ from types import SimpleNamespace
3841from dataclasses import dataclass
3942
4043logger = logging .getLogger (__name__ )
4346PROTOCOL_VERSION = 1
4447
4548
49+ class UnauthorizedException (Exception ):
50+ pass
51+
52+
4653@dataclass
4754class Session ():
4855
@@ -55,7 +62,8 @@ class Session():
5562 def valid (self ) -> bool :
5663 session_age = time .time () - self .session_started
5764
58- if session_age > self .session_token_ttl :
65+ # This triggers refresh if the session is about to expire.
66+ if session_age > self .session_token_ttl - os .getenv ("SESSION_EXPIRTY_OFFSET_SEC" , 10 ):
5967 return False
6068
6169 return True
@@ -81,8 +89,7 @@ class Connector(object):
8189
8290 def __init__ (self , host = "localhost" , port = 55555 ,
8391 user = "" , password = "" , token = "" ,
84- session = None ,
85- use_ssl = True ):
92+ use_ssl = True , shared_data = None ):
8693
8794 self .use_ssl = use_ssl
8895
@@ -93,17 +100,18 @@ def __init__(self, host="localhost", port=55555,
93100 self .last_response = ''
94101 self .last_query_time = 0
95102
96- self .session = None
97-
98103 self ._connect ()
99104
100- if session :
101- self .session = session
102- else :
105+ if shared_data is None :
106+ self .shared_data = SimpleNamespace ()
107+ self .shared_data .session = None
108+ self .shared_data .lock = Lock ()
103109 try :
104110 self ._authenticate (user , password , token )
105111 except Exception as e :
106112 raise Exception ("Authentication failed:" , str (e ))
113+ else :
114+ self .shared_data = shared_data
107115
108116 def __del__ (self ):
109117
@@ -152,39 +160,44 @@ def _authenticate(self, user, password="", token=""):
152160 if session_info ["status" ] != 0 :
153161 raise Exception (session_info ["info" ])
154162
155- self .session = Session (session_info ["session_token" ],
156- session_info ["refresh_token" ],
157- session_info ["session_token_expires_in" ],
158- session_info ["refresh_token_expires_in" ],
159- )
163+ self .shared_data .session = Session (session_info ["session_token" ],
164+ session_info ["refresh_token" ],
165+ session_info ["session_token_expires_in" ],
166+ session_info ["refresh_token_expires_in" ],
167+ time .time ()
168+ )
160169
161170 def _check_session_status (self ):
162-
163- if not self .session :
171+ if not self .shared_data .session :
164172 return
165173
166- if not self .session .valid ():
167- self ._refresh_token ()
174+ if not self .shared_data .session .valid ():
175+ with self .shared_data .lock :
176+ self ._refresh_token ()
168177
169178 def _refresh_token (self ):
170-
171179 query = [{
172180 "RefreshToken" : {
173- "refresh_token" : self .session .refresh_token
181+ "refresh_token" : self .shared_data . session .refresh_token
174182 }
175183 }]
176184
177185 response , _ = self ._query (query , [])
178186
179- session_info = response [0 ]["RefreshToken" ]
180- if session_info ["status" ] != 0 :
181- raise Exception (session_info ["info" ])
182-
183- self .session = Session (session_info ["session_token" ],
184- session_info ["refresh_token" ],
185- session_info ["session_token_expires_in" ],
186- session_info ["refresh_token_expires_in" ],
187- )
187+ logger .info (f"Refresh token response: \r \n { response } " )
188+ if isinstance (response , list ):
189+ session_info = response [0 ]["RefreshToken" ]
190+ if session_info ["status" ] != 0 :
191+ raise UnauthorizedException (response )
192+
193+ self .shared_data .session = Session (session_info ["session_token" ],
194+ session_info ["refresh_token" ],
195+ session_info ["session_token_expires_in" ],
196+ session_info ["refresh_token_expires_in" ],
197+ time .time ()
198+ )
199+ else :
200+ raise UnauthorizedException (response )
188201
189202 def _connect (self ):
190203
@@ -258,8 +271,8 @@ def _query(self, query, blob_array = []):
258271 query_msg .json = query_str
259272
260273 # Set Auth token, only when not authenticated before
261- if self .session :
262- query_msg .token = self .session .session_token
274+ if self .shared_data . session :
275+ query_msg .token = self .shared_data . session .session_token
263276
264277 for blob in blob_array :
265278 query_msg .blobs .append (blob )
@@ -280,21 +293,53 @@ def _query(self, query, blob_array = []):
280293 return (self .last_response , response_blob_array )
281294
282295 def query (self , q , blobs = []):
283-
284- self ._check_session_status ()
285-
296+ """
297+ Query the database with a query string or a json object.
298+ First it checks if the session is valid, if not, it refreshes the token.
299+ Then it sends the query to the server and returns the response.
300+
301+ Args:
302+ q (json): native query to be sent
303+ blobs (list, optional): Blobs if needed with the query. Defaults to [].
304+
305+ Raises:
306+ ConnectionError: Fatal error, connection to server lost
307+
308+ Returns:
309+ _type_: _description_
310+ """
311+ self ._renew_session ()
286312 try :
287313 start = time .time ()
288314 self .response , self .blobs = self ._query (q , blobs )
315+ if not isinstance (self .response , list ) and self .response ["info" ] == "Not Authenticated!" :
316+ # The case where session is valid, but expires while query is sent.
317+ # Hope is that the query send won't be longer than the session ttl.
318+ logger .warn (
319+ f"Session expired while query was sent. Retrying... \r \n { traceback .format_stack (limit = 5 )} " )
320+ self ._renew_session ()
321+ start = time .time ()
322+ self .response , self .blobs = self ._query (q , blobs )
289323 self .last_query_time = time .time () - start
290324 return self .response , self .blobs
291325 except BaseException as e :
292- print (e )
326+ logger . critical (e )
293327 raise ConnectionError ("ApertureDB disconnected" )
294328
295- def create_new_connection (self ):
296-
297- return Connector (self .host , self .port , session = self .session )
329+ def _renew_session (self ):
330+ count = 0
331+ while count < 3 :
332+ try :
333+ self ._check_session_status ()
334+ break
335+ except UnauthorizedException as e :
336+ logger .warn (
337+ f"[Attempt { count + 1 } of 3] Failed to refresh token. Details: \r \n { traceback .format_exc (limit = 5 )} " )
338+ time .sleep (1 )
339+ count += 1
340+
341+ def create_new_connection (self ) -> Connector :
342+ return Connector (self .host , self .port , shared_data = self .shared_data )
298343
299344 def get_last_response_str (self ):
300345
0 commit comments