33from ssl import SSLContext
44from types import TracebackType
55from typing import Any , Dict , List , Optional , Type , Union
6+ from uuid import uuid4
67
7- from httpx import Timeout
8+ from httpx import Timeout , codes
89
910from firebolt .async_db .cursor import Cursor , CursorV1 , CursorV2
10- from firebolt .async_db .util import _get_system_engine_url_and_params
1111from firebolt .client import DEFAULT_API_URL
1212from firebolt .client .auth import Auth
1313from firebolt .client .auth .base import FireboltAuthVersion
2020 AsyncQueryInfo ,
2121 BaseConnection ,
2222 _parse_async_query_info_results ,
23+ get_cached_system_engine_info ,
24+ get_user_agent_for_connection ,
25+ set_cached_system_engine_info ,
2326)
24- from firebolt .common .cache import _firebolt_system_engine_cache
2527from firebolt .common .constants import DEFAULT_TIMEOUT_SECONDS
28+ from firebolt .utils .cache import EngineInfo
2629from firebolt .utils .exception import (
30+ AccountNotFoundOrNoAccessError ,
2731 ConfigurationError ,
2832 ConnectionClosedError ,
2933 FireboltError ,
34+ InterfaceError ,
3035)
3136from firebolt .utils .firebolt_core import (
3237 get_core_certificate_context ,
3338 parse_firebolt_core_url ,
3439 validate_firebolt_core_parameters ,
3540)
36- from firebolt .utils .usage_tracker import get_user_agent_header
37- from firebolt .utils .util import fix_url_schema , validate_engine_name_and_url_v1
41+ from firebolt .utils .urls import GATEWAY_HOST_BY_ACCOUNT_NAME
42+ from firebolt .utils .util import (
43+ fix_url_schema ,
44+ parse_url_and_params ,
45+ validate_engine_name_and_url_v1 ,
46+ )
3847
3948
4049class Connection (BaseConnection ):
@@ -71,6 +80,7 @@ class Connection(BaseConnection):
7180 "_is_closed" ,
7281 "client_class" ,
7382 "cursor_type" ,
83+ "id" ,
7484 )
7585
7686 def __init__ (
@@ -81,12 +91,14 @@ def __init__(
8191 cursor_type : Type [Cursor ],
8292 api_endpoint : str ,
8393 init_parameters : Optional [Dict [str , Any ]] = None ,
94+ id : str = uuid4 ().hex ,
8495 ):
8596 super ().__init__ (cursor_type )
8697 self .api_endpoint = api_endpoint
8798 self .engine_url = engine_url
8899 self ._cursors : List [Cursor ] = []
89100 self ._client = client
101+ self .id = id
90102 self .init_parameters = init_parameters or {}
91103 if database :
92104 self .init_parameters ["database" ] = database
@@ -225,13 +237,13 @@ async def connect(
225237 if not auth :
226238 raise ConfigurationError ("auth is required to connect." )
227239
240+ api_endpoint = fix_url_schema (api_endpoint )
228241 # Type checks
229242 assert auth is not None
230- user_drivers = additional_parameters .get ("user_drivers" , [])
231- user_clients = additional_parameters .get ("user_clients" , [])
232- user_agent_header = get_user_agent_header (user_drivers , user_clients )
233- if disable_cache :
234- _firebolt_system_engine_cache .disable ()
243+ connection_id = uuid4 ().hex
244+ user_agent_header = get_user_agent_for_connection (
245+ auth , connection_id , account_name , additional_parameters , disable_cache
246+ )
235247 # Use CORE if auth is FireboltCore
236248 # Use V2 if auth is ClientCredentials
237249 # Use V1 if auth is ServiceAccount or UsernamePassword
@@ -254,6 +266,8 @@ async def connect(
254266 database = database ,
255267 engine_name = engine_name ,
256268 api_endpoint = api_endpoint ,
269+ connection_id = connection_id ,
270+ disable_cache = disable_cache ,
257271 )
258272 elif auth_version == FireboltAuthVersion .V1 :
259273 return await connect_v1 (
@@ -264,6 +278,7 @@ async def connect(
264278 engine_name = engine_name ,
265279 engine_url = engine_url ,
266280 api_endpoint = api_endpoint ,
281+ connection_id = connection_id ,
267282 )
268283 else :
269284 raise ConfigurationError (f"Unsupported auth type: { type (auth )} " )
@@ -272,10 +287,12 @@ async def connect(
272287async def connect_v2 (
273288 auth : Auth ,
274289 user_agent_header : str ,
290+ connection_id : str ,
275291 account_name : Optional [str ] = None ,
276292 database : Optional [str ] = None ,
277293 engine_name : Optional [str ] = None ,
278294 api_endpoint : str = DEFAULT_API_URL ,
295+ disable_cache : bool = False ,
279296) -> Connection :
280297 """Connect to Firebolt.
281298
@@ -301,10 +318,8 @@ async def connect_v2(
301318 assert auth is not None
302319 assert account_name is not None
303320
304- api_endpoint = fix_url_schema (api_endpoint )
305-
306- system_engine_url , system_engine_params = await _get_system_engine_url_and_params (
307- auth , account_name , api_endpoint
321+ system_engine_info = await _get_system_engine_url_and_params (
322+ auth , account_name , api_endpoint , connection_id , disable_cache
308323 )
309324
310325 client = AsyncClientV2 (
@@ -316,19 +331,21 @@ async def connect_v2(
316331 )
317332
318333 async with Connection (
319- system_engine_url ,
334+ system_engine_info . url ,
320335 None ,
321336 client ,
322337 CursorV2 ,
323338 api_endpoint ,
324- system_engine_params ,
339+ system_engine_info .params ,
340+ connection_id ,
325341 ) as system_engine_connection :
326342
327343 cursor = system_engine_connection .cursor ()
344+
328345 if database :
329- await cursor .execute ( f'USE DATABASE " { database } "' )
346+ await cursor .use_database ( database , cache = not disable_cache )
330347 if engine_name :
331- await cursor .execute ( f'USE ENGINE " { engine_name } "' )
348+ await cursor .use_engine ( engine_name , cache = not disable_cache )
332349 # Ensure cursors created from this connection are using the same starting
333350 # database and engine
334351 return Connection (
@@ -338,12 +355,14 @@ async def connect_v2(
338355 CursorV2 ,
339356 api_endpoint ,
340357 cursor .parameters ,
358+ connection_id ,
341359 )
342360
343361
344362async def connect_v1 (
345363 auth : Auth ,
346364 user_agent_header : str ,
365+ connection_id : str ,
347366 database : Optional [str ] = None ,
348367 account_name : Optional [str ] = None ,
349368 engine_name : Optional [str ] = None ,
@@ -358,8 +377,6 @@ async def connect_v1(
358377
359378 validate_engine_name_and_url_v1 (engine_name , engine_url )
360379
361- api_endpoint = fix_url_schema (api_endpoint )
362-
363380 no_engine_client = AsyncClientV1 (
364381 auth = auth ,
365382 base_url = api_endpoint ,
@@ -397,11 +414,7 @@ async def connect_v1(
397414 headers = {"User-Agent" : user_agent_header },
398415 )
399416 return Connection (
400- engine_url ,
401- database ,
402- client ,
403- CursorV1 ,
404- api_endpoint ,
417+ engine_url , database , client , CursorV1 , api_endpoint , id = connection_id
405418 )
406419
407420
@@ -448,3 +461,39 @@ def connect_core(
448461 cursor_type = CursorV2 ,
449462 api_endpoint = verified_url ,
450463 )
464+
465+
466+ async def _get_system_engine_url_and_params (
467+ auth : Auth ,
468+ account_name : str ,
469+ api_endpoint : str ,
470+ connection_id : str ,
471+ disable_cache : bool = False ,
472+ ) -> EngineInfo :
473+ cache_key , cached_result = get_cached_system_engine_info (
474+ auth , account_name , disable_cache
475+ )
476+ if cached_result :
477+ return cached_result
478+
479+ async with AsyncClientV2 (
480+ auth = auth ,
481+ base_url = api_endpoint ,
482+ account_name = account_name ,
483+ api_endpoint = api_endpoint ,
484+ timeout = Timeout (DEFAULT_TIMEOUT_SECONDS ),
485+ ) as client :
486+ url = GATEWAY_HOST_BY_ACCOUNT_NAME .format (account_name = account_name )
487+ response = await client .get (url = url )
488+ if response .status_code == codes .NOT_FOUND :
489+ raise AccountNotFoundOrNoAccessError (account_name )
490+ if response .status_code != codes .OK :
491+ raise InterfaceError (
492+ f"Unable to retrieve system engine endpoint { url } : "
493+ f"{ response .status_code } { response .content .decode ()} "
494+ )
495+ url , params = parse_url_and_params (response .json ()["engineUrl" ])
496+
497+ return set_cached_system_engine_info (
498+ cache_key , connection_id , url , params , disable_cache
499+ )
0 commit comments