Skip to content

Commit eed8a07

Browse files
committed
Changes after testing
1 parent c09412e commit eed8a07

File tree

4 files changed

+36
-17
lines changed

4 files changed

+36
-17
lines changed

src/firebolt/async_db/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,7 @@ async def connect(
237237
if not auth:
238238
raise ConfigurationError("auth is required to connect.")
239239

240-
if account_name:
241-
auth._account_name = account_name
240+
auth.account = account_name
242241

243242
api_endpoint = fix_url_schema(api_endpoint)
244243
# Type checks

src/firebolt/client/auth/base.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import abstractmethod
33
from enum import IntEnum
44
from time import time
5-
from typing import AsyncGenerator, Generator, Optional
5+
from typing import AsyncGenerator, Generator, Optional, Tuple
66

77
from anyio import Lock
88
from httpx import Auth as HttpxAuth
@@ -57,6 +57,17 @@ def __init__(self, use_token_cache: bool = True):
5757
self._expires: Optional[int] = None
5858
self._lock = Lock()
5959

60+
@property
61+
def account(self) -> Optional[str]:
62+
return self._account_name
63+
64+
@account.setter
65+
def account(self, value: str) -> None:
66+
self._account_name = value
67+
# Now we have all the elements to fetch the cached token
68+
if not self._token:
69+
self._token, self._expires = self._get_cached_token()
70+
6071
def copy(self) -> "Auth":
6172
"""Make another auth object with same credentials.
6273
@@ -109,7 +120,7 @@ def expired(self) -> bool:
109120
"""
110121
return self._expires is not None and self._expires <= int(time())
111122

112-
def _get_cached_token(self) -> Optional[str]:
123+
def _get_cached_token(self) -> Tuple[Optional[str], Optional[int]]:
113124
"""If caching is enabled, get token from cache.
114125
115126
If caching is disabled, None is returned.
@@ -118,17 +129,17 @@ def _get_cached_token(self) -> Optional[str]:
118129
Optional[str]: Token if any, and if caching is enabled; None otherwise
119130
"""
120131
if not self._use_token_cache:
121-
return None
132+
return (None, None)
122133

123134
cache_key = SecureCacheKey(
124135
[self.principal, self.secret, self._account_name], self.secret
125136
)
126137
connection_info = _firebolt_cache.get(cache_key)
127138

128139
if connection_info and connection_info.token:
129-
return connection_info.token
140+
return (connection_info.token, connection_info.expiry_time)
130141

131-
return None
142+
return (None, None)
132143

133144
def _cache_token(self) -> None:
134145
"""If caching is enabled, cache token."""

src/firebolt/db/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def connect(
6666
if not auth:
6767
raise ConfigurationError("auth is required to connect.")
6868

69-
if account_name:
70-
auth._account_name = account_name
69+
auth.account = account_name
7170

7271
api_endpoint = fix_url_schema(api_endpoint)
7372
# Type checks

src/firebolt/utils/cache.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,24 @@ def __post_init__(self) -> None:
7171
"""
7272
if self.system_engine and isinstance(self.system_engine, dict):
7373
self.system_engine = EngineInfo(**self.system_engine)
74-
self.databases = {
75-
k: DatabaseInfo(**v)
76-
for k, v in self.databases.items()
77-
if isinstance(v, dict)
78-
}
79-
self.engines = {
80-
k: EngineInfo(**v) for k, v in self.engines.items() if isinstance(v, dict)
81-
}
74+
75+
# Convert dict values to dataclasses, keep existing dataclass objects
76+
new_databases = {}
77+
for k, db in self.databases.items():
78+
if isinstance(db, dict):
79+
new_databases[k] = DatabaseInfo(**db)
80+
else:
81+
new_databases[k] = db
82+
self.databases = new_databases
83+
84+
# Convert dict values to dataclasses, keep existing dataclass objects
85+
new_engines = {}
86+
for k, engine in self.engines.items():
87+
if isinstance(engine, dict):
88+
new_engines[k] = EngineInfo(**engine)
89+
else:
90+
new_engines[k] = engine
91+
self.engines = new_engines
8292

8393

8494
def noop_if_disabled(func: Callable) -> Callable:

0 commit comments

Comments
 (0)