Skip to content

Commit 175d4d1

Browse files
Add Firebolt Core connection support
- Add FireboltCore import from firebolt.client.auth - Modify create_connect_args to detect 'url' parameter for Core connections - Update _determine_auth to return FireboltCore auth when is_core=True - Skip client_id/secret and account_name validation for Core connections - Pass Core URL to SDK via kwargs['url'] - Add comprehensive unit tests for Core connection behavior - Maintain backward compatibility with existing connection patterns Co-Authored-By: [email protected] <[email protected]>
1 parent fd01b3c commit 175d4d1

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed

src/firebolt_db/firebolt_dialect.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import firebolt.db as dbapi
77
import sqlalchemy.types as sqltypes
8-
from firebolt.client.auth import Auth, ClientCredentials, UsernamePassword
8+
from firebolt.client.auth import Auth, ClientCredentials, UsernamePassword, FireboltCore
99
from firebolt.db import Cursor
1010
from sqlalchemy.engine import Connection as AlchemyConnection
1111
from sqlalchemy.engine import ExecutionContext, default
@@ -145,18 +145,32 @@ def create_connect_args(self, url: URL) -> Tuple[List, Dict]:
145145
"""
146146
Build firebolt-sdk compatible connection arguments.
147147
URL format : firebolt://id:secret@host:port/db_name
148+
For Core: firebolt://host:port/db_name?url=core_url
148149
"""
149150
parameters = dict(url.query)
151+
152+
is_core_connection = "url" in parameters
153+
core_url = parameters.pop("url", None) if is_core_connection else None
154+
150155
# parameters are all passed as a string, we need to convert
151156
# bool flag to boolean for SDK compatibility
152157
token_cache_flag = bool(strtobool(parameters.pop("use_token_cache", "True")))
153-
auth = _determine_auth(url.username, url.password, token_cache_flag)
158+
159+
if is_core_connection:
160+
auth = _determine_auth("", "", token_cache_flag, is_core=True)
161+
else:
162+
auth = _determine_auth(url.username, url.password, token_cache_flag)
163+
154164
kwargs: Dict[str, Union[str, Auth, Dict[str, Any], None]] = {
155165
"database": url.host or None,
156166
"auth": auth,
157167
"engine_name": url.database,
158168
"additional_parameters": {},
159169
}
170+
171+
if core_url:
172+
kwargs["url"] = core_url
173+
160174
additional_parameters = {}
161175
if "account_name" in parameters:
162176
kwargs["account_name"] = parameters.pop("account_name")
@@ -366,8 +380,10 @@ def get_is_nullable(column_is_nullable: int) -> bool:
366380
return column_is_nullable == 1
367381

368382

369-
def _determine_auth(key: str, secret: str, token_cache_flag: bool = True) -> Auth:
370-
if "@" in key:
383+
def _determine_auth(key: str, secret: str, token_cache_flag: bool = True, is_core: bool = False) -> Auth:
384+
if is_core:
385+
return FireboltCore()
386+
elif "@" in key:
371387
return UsernamePassword(key, secret, token_cache_flag)
372388
else:
373389
return ClientCredentials(key, secret, token_cache_flag)

tests/unit/test_firebolt_dialect.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from firebolt_db.firebolt_dialect import dialect as dialect_definition
2020
from firebolt_db.firebolt_dialect import resolve_type
21+
from firebolt.client.auth import FireboltCore
2122

2223

2324
class TestFireboltDialect:
@@ -302,6 +303,43 @@ def test_unicode_description(
302303
):
303304
assert dialect._check_unicode_description(connection)
304305

306+
def test_create_connect_args_core_connection(self, dialect: FireboltDialect):
307+
connection_url = (
308+
"test_engine://test_db_name/test_engine_name?"
309+
"url=http://localhost:8080"
310+
)
311+
u = url.make_url(connection_url)
312+
result_list, result_dict = dialect.create_connect_args(u)
313+
314+
assert result_dict["engine_name"] == "test_engine_name"
315+
assert result_dict["database"] == "test_db_name"
316+
assert result_dict["url"] == "http://localhost:8080"
317+
assert isinstance(result_dict["auth"], FireboltCore)
318+
assert "account_name" not in result_dict
319+
assert result_list == []
320+
321+
def test_create_connect_args_core_connection_with_database(self, dialect: FireboltDialect):
322+
connection_url = (
323+
"test_engine://test_db_name?"
324+
"url=http://localhost:8080"
325+
)
326+
u = url.make_url(connection_url)
327+
result_list, result_dict = dialect.create_connect_args(u)
328+
329+
assert result_dict["database"] == "test_db_name"
330+
assert result_dict["url"] == "http://localhost:8080"
331+
assert isinstance(result_dict["auth"], FireboltCore)
332+
assert result_dict["engine_name"] is None
333+
assert result_list == []
334+
335+
def test_create_connect_args_core_no_credentials_required(self, dialect: FireboltDialect):
336+
connection_url = "test_engine://test_db_name?url=http://localhost:8080"
337+
u = url.make_url(connection_url)
338+
339+
result_list, result_dict = dialect.create_connect_args(u)
340+
assert isinstance(result_dict["auth"], FireboltCore)
341+
assert "account_name" not in result_dict
342+
305343

306344
def test_get_is_nullable():
307345
assert firebolt_db.firebolt_dialect.get_is_nullable(1)

0 commit comments

Comments
 (0)