Skip to content

Commit fcab0fc

Browse files
authored
feat(FIR-46254): Extend caching (#453)
1 parent 48a95c0 commit fcab0fc

28 files changed

+2329
-311
lines changed

src/firebolt/async_db/connection.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from ssl import SSLContext
44
from types import TracebackType
55
from typing import Any, Dict, List, Optional, Type, Union
6+
from uuid import uuid4
67

7-
from httpx import Timeout
8+
from httpx import Timeout, codes
89

910
from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2
10-
from firebolt.async_db.util import _get_system_engine_url_and_params
1111
from firebolt.client import DEFAULT_API_URL
1212
from firebolt.client.auth import Auth
1313
from firebolt.client.auth.base import FireboltAuthVersion
@@ -20,21 +20,30 @@
2020
AsyncQueryInfo,
2121
BaseConnection,
2222
_parse_async_query_info_results,
23+
get_cached_system_engine_info,
24+
get_user_agent_for_connection,
25+
set_cached_system_engine_info,
2326
)
24-
from firebolt.common.cache import _firebolt_system_engine_cache
2527
from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS
28+
from firebolt.utils.cache import EngineInfo
2629
from firebolt.utils.exception import (
30+
AccountNotFoundOrNoAccessError,
2731
ConfigurationError,
2832
ConnectionClosedError,
2933
FireboltError,
34+
InterfaceError,
3035
)
3136
from firebolt.utils.firebolt_core import (
3237
get_core_certificate_context,
3338
parse_firebolt_core_url,
3439
validate_firebolt_core_parameters,
3540
)
36-
from firebolt.utils.usage_tracker import get_user_agent_header
37-
from firebolt.utils.util import fix_url_schema, validate_engine_name_and_url_v1
41+
from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME
42+
from firebolt.utils.util import (
43+
fix_url_schema,
44+
parse_url_and_params,
45+
validate_engine_name_and_url_v1,
46+
)
3847

3948

4049
class Connection(BaseConnection):
@@ -71,6 +80,7 @@ class Connection(BaseConnection):
7180
"_is_closed",
7281
"client_class",
7382
"cursor_type",
83+
"id",
7484
)
7585

7686
def __init__(
@@ -81,12 +91,14 @@ def __init__(
8191
cursor_type: Type[Cursor],
8292
api_endpoint: str,
8393
init_parameters: Optional[Dict[str, Any]] = None,
94+
id: str = uuid4().hex,
8495
):
8596
super().__init__(cursor_type)
8697
self.api_endpoint = api_endpoint
8798
self.engine_url = engine_url
8899
self._cursors: List[Cursor] = []
89100
self._client = client
101+
self.id = id
90102
self.init_parameters = init_parameters or {}
91103
if database:
92104
self.init_parameters["database"] = database
@@ -225,13 +237,13 @@ async def connect(
225237
if not auth:
226238
raise ConfigurationError("auth is required to connect.")
227239

240+
api_endpoint = fix_url_schema(api_endpoint)
228241
# Type checks
229242
assert auth is not None
230-
user_drivers = additional_parameters.get("user_drivers", [])
231-
user_clients = additional_parameters.get("user_clients", [])
232-
user_agent_header = get_user_agent_header(user_drivers, user_clients)
233-
if disable_cache:
234-
_firebolt_system_engine_cache.disable()
243+
connection_id = uuid4().hex
244+
user_agent_header = get_user_agent_for_connection(
245+
auth, connection_id, account_name, additional_parameters, disable_cache
246+
)
235247
# Use CORE if auth is FireboltCore
236248
# Use V2 if auth is ClientCredentials
237249
# Use V1 if auth is ServiceAccount or UsernamePassword
@@ -254,6 +266,8 @@ async def connect(
254266
database=database,
255267
engine_name=engine_name,
256268
api_endpoint=api_endpoint,
269+
connection_id=connection_id,
270+
disable_cache=disable_cache,
257271
)
258272
elif auth_version == FireboltAuthVersion.V1:
259273
return await connect_v1(
@@ -264,6 +278,7 @@ async def connect(
264278
engine_name=engine_name,
265279
engine_url=engine_url,
266280
api_endpoint=api_endpoint,
281+
connection_id=connection_id,
267282
)
268283
else:
269284
raise ConfigurationError(f"Unsupported auth type: {type(auth)}")
@@ -272,10 +287,12 @@ async def connect(
272287
async def connect_v2(
273288
auth: Auth,
274289
user_agent_header: str,
290+
connection_id: str,
275291
account_name: Optional[str] = None,
276292
database: Optional[str] = None,
277293
engine_name: Optional[str] = None,
278294
api_endpoint: str = DEFAULT_API_URL,
295+
disable_cache: bool = False,
279296
) -> Connection:
280297
"""Connect to Firebolt.
281298
@@ -301,10 +318,8 @@ async def connect_v2(
301318
assert auth is not None
302319
assert account_name is not None
303320

304-
api_endpoint = fix_url_schema(api_endpoint)
305-
306-
system_engine_url, system_engine_params = await _get_system_engine_url_and_params(
307-
auth, account_name, api_endpoint
321+
system_engine_info = await _get_system_engine_url_and_params(
322+
auth, account_name, api_endpoint, connection_id, disable_cache
308323
)
309324

310325
client = AsyncClientV2(
@@ -316,19 +331,21 @@ async def connect_v2(
316331
)
317332

318333
async with Connection(
319-
system_engine_url,
334+
system_engine_info.url,
320335
None,
321336
client,
322337
CursorV2,
323338
api_endpoint,
324-
system_engine_params,
339+
system_engine_info.params,
340+
connection_id,
325341
) as system_engine_connection:
326342

327343
cursor = system_engine_connection.cursor()
344+
328345
if database:
329-
await cursor.execute(f'USE DATABASE "{database}"')
346+
await cursor.use_database(database, cache=not disable_cache)
330347
if engine_name:
331-
await cursor.execute(f'USE ENGINE "{engine_name}"')
348+
await cursor.use_engine(engine_name, cache=not disable_cache)
332349
# Ensure cursors created from this connection are using the same starting
333350
# database and engine
334351
return Connection(
@@ -338,12 +355,14 @@ async def connect_v2(
338355
CursorV2,
339356
api_endpoint,
340357
cursor.parameters,
358+
connection_id,
341359
)
342360

343361

344362
async def connect_v1(
345363
auth: Auth,
346364
user_agent_header: str,
365+
connection_id: str,
347366
database: Optional[str] = None,
348367
account_name: Optional[str] = None,
349368
engine_name: Optional[str] = None,
@@ -358,8 +377,6 @@ async def connect_v1(
358377

359378
validate_engine_name_and_url_v1(engine_name, engine_url)
360379

361-
api_endpoint = fix_url_schema(api_endpoint)
362-
363380
no_engine_client = AsyncClientV1(
364381
auth=auth,
365382
base_url=api_endpoint,
@@ -397,11 +414,7 @@ async def connect_v1(
397414
headers={"User-Agent": user_agent_header},
398415
)
399416
return Connection(
400-
engine_url,
401-
database,
402-
client,
403-
CursorV1,
404-
api_endpoint,
417+
engine_url, database, client, CursorV1, api_endpoint, id=connection_id
405418
)
406419

407420

@@ -448,3 +461,39 @@ def connect_core(
448461
cursor_type=CursorV2,
449462
api_endpoint=verified_url,
450463
)
464+
465+
466+
async def _get_system_engine_url_and_params(
    auth: Auth,
    account_name: str,
    api_endpoint: str,
    connection_id: str,
    disable_cache: bool = False,
) -> EngineInfo:
    """Resolve the system engine URL and parameters for an account.

    A cached entry is returned as-is when the cache lookup yields one.
    Otherwise the account gateway endpoint is queried and the ``engineUrl``
    field of its JSON response is split into a URL and parameters.

    Args:
        auth (Auth): Authentication object used for the cache lookup and
            for the gateway request
        account_name (str): Account whose system engine is being resolved
        api_endpoint (str): Base URL of the Firebolt API
        connection_id (str): Id of the connection, forwarded to
            ``set_cached_system_engine_info``
        disable_cache (bool): Forwarded to both cache helpers.
            Defaults to False.

    Returns:
        EngineInfo: The cached entry, or whatever
            ``set_cached_system_engine_info`` returns for the fetched values

    Raises:
        AccountNotFoundOrNoAccessError: The gateway answered 404
        InterfaceError: The gateway answered with any other non-200 status
    """
    cache_key, cached_info = get_cached_system_engine_info(
        auth, account_name, disable_cache
    )
    if cached_info:
        return cached_info

    gateway_client = AsyncClientV2(
        auth=auth,
        base_url=api_endpoint,
        account_name=account_name,
        api_endpoint=api_endpoint,
        timeout=Timeout(DEFAULT_TIMEOUT_SECONDS),
    )
    async with gateway_client as client:
        gateway_url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name)
        response = await client.get(url=gateway_url)

        status = response.status_code
        if status == codes.NOT_FOUND:
            raise AccountNotFoundOrNoAccessError(account_name)
        if status != codes.OK:
            raise InterfaceError(
                f"Unable to retrieve system engine endpoint {gateway_url}: "
                f"{response.status_code} {response.content.decode()}"
            )

        engine_url, engine_params = parse_url_and_params(
            response.json()["engineUrl"]
        )

    return set_cached_system_engine_info(
        cache_key, connection_id, engine_url, engine_params, disable_cache
    )

src/firebolt/async_db/cursor.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from firebolt.async_db.connection import Connection
5959

6060
from firebolt.utils.async_util import anext, async_islice
61+
from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo
6162
from firebolt.utils.util import Timer, raise_error_from_response
6263

6364
logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ def __init__(
8586
**kwargs: Any,
8687
) -> None:
8788
super().__init__(*args, **kwargs)
88-
self._client = client
8989
self.connection = connection
90+
self._client: AsyncClient = client
9091
self.engine_url = connection.engine_url
9192
self._row_set: Optional[BaseAsyncRowSet] = None
9293
if connection.init_parameters:
@@ -332,6 +333,41 @@ async def _handle_query_execution(
332333
await self._parse_response_headers(resp.headers)
333334
await self._append_row_set_from_response(resp)
334335

336+
async def use_database(self, database: str, cache: bool = True) -> None:
    """Make ``database`` the cursor's current database, consulting the cache.

    Args:
        database (str): Name of the database to switch to
        cache (bool): When False, the cache is bypassed and ``USE DATABASE``
            is always executed. Defaults to True.
    """
    if not cache:
        await self.execute(f'USE DATABASE "{database}"')
        return

    # Start a fresh record for this connection when nothing is cached yet.
    record = self.get_cache_record() or ConnectionInfo(id=self.connection.id)

    if record.databases.get(database):
        # Already recorded: set the attribute without running a query.
        self.database = database
        return

    await self.execute(f'USE DATABASE "{database}"')
    record.databases[database] = DatabaseInfo(database)
    self.set_cache_record(record)
352+
353+
async def use_engine(self, engine: str, cache: bool = True) -> None:
    """Point the cursor at ``engine``, consulting the cache first.

    Args:
        engine (str): Name of the engine to switch to
        cache (bool): When False, the cache is bypassed and ``USE ENGINE``
            is always executed. Defaults to True.
    """
    if not cache:
        await self.execute(f'USE ENGINE "{engine}"')
        return

    # Start a fresh record for this connection when nothing is cached yet.
    record = self.get_cache_record() or ConnectionInfo(id=self.connection.id)

    known_engine = record.engines.get(engine)
    if known_engine:
        # Reuse the stored URL and parameters instead of running a query.
        self.engine_url = known_engine.url
        self._update_set_parameters(known_engine.params)
        return

    await self.execute(f'USE ENGINE "{engine}"')
    # NOTE(review): presumably execute() refreshes engine_url and parameters
    # before they are stored here - confirm against the execute implementation.
    record.engines[engine] = EngineInfo(self.engine_url, self.parameters)
    self.set_cache_record(record)
370+
335371
@check_not_closed
336372
async def execute(
337373
self,

src/firebolt/async_db/util.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

src/firebolt/client/auth/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,24 @@ def token(self) -> Optional[str]:
6565
"""
6666
return self._token
6767

68+
@property
69+
@abstractmethod
70+
def principal(self) -> str:
    """Return the principal this auth authenticates as.

    Implementations supply the identifying value, such as a username
    or an id.

    Returns:
        str: Principal string
    """
76+
77+
@property
78+
@abstractmethod
79+
def secret(self) -> str:
    """Return the secret paired with the principal.

    Implementations supply the confidential value, such as a password
    or a secret key.

    Returns:
        str: Secret string
    """
85+
6886
@abstractmethod
6987
def get_firebolt_version(self) -> FireboltAuthVersion:
7088
"""Get Firebolt version from auth.

src/firebolt/client/auth/client_credentials.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ def copy(self) -> "ClientCredentials":
5353
self.client_id, self.client_secret, self._use_token_cache
5454
)
5555

56+
@property
57+
def principal(self) -> str:
    """Return the client id, which is the principal for this auth.

    Returns:
        str: Principal client id
    """
    client_id = self.client_id
    return client_id
64+
65+
@property
66+
def secret(self) -> str:
    """Return the client secret, which is the secret for this auth.

    Returns:
        str: Secret
    """
    client_secret = self.client_secret
    return client_secret
73+
5674
def get_firebolt_version(self) -> FireboltAuthVersion:
5775
"""Get Firebolt version from auth.
5876

0 commit comments

Comments
 (0)