11import json
22import logging
3+ from datetime import datetime , timedelta , timezone
34
45from okdata .sdk .auth .credentials .client_credentials import ClientCredentialsProvider
56from okdata .sdk .auth .credentials .common import (
67 TokenProviderNotInitialized ,
78 TokenRefreshError ,
89)
910from okdata .sdk .auth .credentials .password_grant import TokenServiceProvider
10- from okdata .sdk .auth .util import is_token_expired
1111from okdata .sdk .exceptions import ApiAuthenticateError
1212from okdata .sdk .file_cache import FileCache
1313
1414log = logging .getLogger ()
1515
1616
17- class Authenticate (object ):
17+ def _is_expired (timestamp ):
18+ """Return true if `timestamp` has expired (or is just about to expire)."""
19+ return timestamp and (timestamp - datetime .now (timezone .utc )).total_seconds () < 10
20+
21+
22+ class Authenticate :
23+ _access_token = None
24+ _refresh_token = None
25+ _expires_at = None
26+ _refresh_expires_at = None
27+
1828 def __init__ (self , config , token_provider = None , file_cache = None ):
1929 self .token_provider = token_provider
2030 if not self .token_provider :
@@ -30,9 +40,6 @@ def __init__(self, config, token_provider=None, file_cache=None):
3040 if not self .file_cache :
3141 self .file_cache = FileCache (config )
3242
33- self ._access_token = None
34- self ._refresh_token = None
35-
3643 def _resolve_token_provider (self , config ):
3744 # Add more TokenProviders to accept different login methods
3845 strategies = [ClientCredentialsProvider , TokenServiceProvider ]
@@ -54,7 +61,7 @@ def access_token(self):
5461 if not self ._access_token :
5562 self .login ()
5663 # If expired, refresh
57- if is_token_expired (self ._access_token ):
64+ if _is_expired (self ._expires_at ):
5865 self .refresh_access_token ()
5966 return self ._access_token
6067
@@ -66,8 +73,12 @@ def login(self, force=False):
6673 if cached :
6774 self ._access_token = cached ["access_token" ]
6875 self ._refresh_token = cached .get ("refresh_token" )
76+ if expires_at := cached .get ("expires_at" ):
77+ self ._expires_at = datetime .fromisoformat (expires_at )
78+ if refresh_expires_at := cached .get ("refresh_expires_at" ):
79+ self ._refresh_expires_at = datetime .fromisoformat (refresh_expires_at )
6980
70- if self ._access_token and not is_token_expired (self ._access_token ):
81+ if self ._access_token and not _is_expired (self ._expires_at ):
7182 log .info ("Token not expired, skipping" )
7283 return
7384 self .refresh_access_token ()
@@ -78,7 +89,7 @@ def refresh_access_token(self):
7889
7990 tokens = None
8091
81- if self ._refresh_token and not is_token_expired (self ._refresh_token ):
92+ if self ._refresh_token and not _is_expired (self ._refresh_expires_at ):
8293 try :
8394 tokens = self .token_provider .refresh_token (self ._refresh_token )
8495 except TokenRefreshError as e :
@@ -89,6 +100,13 @@ def refresh_access_token(self):
89100 if "access_token" not in tokens :
90101 raise ApiAuthenticateError
91102 self ._refresh_token = tokens .get ("refresh_token" )
103+ self ._expires_at = datetime .now (timezone .utc ) + timedelta (
104+ seconds = tokens .get ("expires_in" )
105+ )
106+ if refresh_expires_in := tokens .get ("refresh_expires_in" ):
107+ self ._refresh_expires_at = datetime .now (timezone .utc ) + timedelta (
108+ seconds = refresh_expires_in
109+ )
92110
93111 self ._access_token = tokens ["access_token" ]
94112 self .file_cache .write_credentials (credentials = self )
@@ -99,6 +117,12 @@ def __repr__(self):
99117 "provider" : self .token_provider .__class__ .__name__ ,
100118 "access_token" : self ._access_token ,
101119 "refresh_token" : self ._refresh_token ,
120+ "expires_at" : self ._expires_at .isoformat () if self ._expires_at else "" ,
121+ "refresh_expires_at" : (
122+ self ._refresh_expires_at .isoformat ()
123+ if self ._refresh_expires_at
124+ else ""
125+ ),
102126 }
103127 )
104128
0 commit comments