|
5 | 5 |
|
6 | 6 | import firebolt.db as dbapi |
7 | 7 | import sqlalchemy.types as sqltypes |
8 | | -from firebolt.client.auth import Auth, ClientCredentials, UsernamePassword |
| 8 | +from firebolt.client.auth import ( |
| 9 | + Auth, |
| 10 | + ClientCredentials, |
| 11 | + FireboltCore, |
| 12 | + UsernamePassword, |
| 13 | +) |
9 | 14 | from firebolt.db import Cursor |
10 | 15 | from sqlalchemy.engine import Connection as AlchemyConnection |
11 | 16 | from sqlalchemy.engine import ExecutionContext, default |
@@ -145,36 +150,95 @@ def create_connect_args(self, url: URL) -> Tuple[List, Dict]: |
145 | 150 | """ |
146 | 151 | Build firebolt-sdk compatible connection arguments. |
147 | 152 | URL format : firebolt://id:secret@host:port/db_name |
| 153 | + For Core: firebolt://db_name?url=http://localhost:8080 |
| 154 | + (full URL including scheme, host, port in url parameter) |
148 | 155 | """ |
149 | 156 | parameters = dict(url.query) |
150 | | - # parameters are all passed as a string, we need to convert |
151 | | - # bool flag to boolean for SDK compatibility |
152 | | - token_cache_flag = bool(strtobool(parameters.pop("use_token_cache", "True"))) |
153 | | - auth = _determine_auth(url.username, url.password, token_cache_flag) |
| 157 | + is_core_connection = "url" in parameters |
| 158 | + |
| 159 | + if is_core_connection: |
| 160 | + self._validate_core_connection(url, parameters) |
| 161 | + |
| 162 | + token_cache_flag = self._parse_token_cache_flag(parameters) |
| 163 | + auth = _determine_auth(url, token_cache_flag) |
| 164 | + kwargs = self._build_connection_kwargs( |
| 165 | + url, parameters, auth, is_core_connection |
| 166 | + ) |
| 167 | + |
| 168 | + return ([], kwargs) |
| 169 | + |
| 170 | + def _validate_core_connection(self, url: URL, parameters: Dict[str, str]) -> None: |
| 171 | + """Validate that Core connection parameters are correct. |
| 172 | +
|
| 173 | + Only validates credentials since FireboltCore auth handles other parameters. |
| 174 | + """ |
| 175 | + if url.username or url.password: |
| 176 | + raise ArgumentError( |
| 177 | + "Core connections do not support username/password authentication" |
| 178 | + ) |
| 179 | + |
| 180 | + def _parse_token_cache_flag(self, parameters: Dict[str, str]) -> bool: |
| 181 | + """Parse and remove token cache flag from parameters.""" |
| 182 | + return bool(strtobool(parameters.pop("use_token_cache", "True"))) |
| 183 | + |
| 184 | + def _build_connection_kwargs( |
| 185 | + self, url: URL, parameters: Dict[str, str], auth: Auth, is_core_connection: bool |
| 186 | + ) -> Dict[str, Union[str, Auth, Dict[str, Any], None]]: |
| 187 | + """Build connection kwargs for the SDK. |
| 188 | +
|
| 189 | + SQLAlchemy URL mapping: |
| 190 | + - url.host -> database (Firebolt database name) |
| 191 | + - url.database -> engine_name (Firebolt engine name) |
| 192 | + """ |
154 | 193 | kwargs: Dict[str, Union[str, Auth, Dict[str, Any], None]] = { |
155 | 194 | "database": url.host or None, |
156 | 195 | "auth": auth, |
157 | 196 | "engine_name": url.database, |
158 | 197 | "additional_parameters": {}, |
159 | 198 | } |
160 | | - additional_parameters = {} |
| 199 | + |
| 200 | + if is_core_connection: |
| 201 | + kwargs["url"] = parameters.pop("url") |
| 202 | + |
| 203 | + self._handle_account_name(parameters, auth, kwargs) |
| 204 | + self._handle_environment_config(kwargs) |
| 205 | + kwargs["additional_parameters"] = self._build_additional_parameters(parameters) |
| 206 | + self._set_parameters = parameters |
| 207 | + |
| 208 | + return kwargs |
| 209 | + |
| 210 | + def _handle_account_name( |
| 211 | + self, |
| 212 | + parameters: Dict[str, str], |
| 213 | + auth: Auth, |
| 214 | + kwargs: Dict[str, Union[str, Auth, Dict[str, Any], None]], |
| 215 | + ) -> None: |
| 216 | + """Handle account_name parameter and validation.""" |
161 | 217 | if "account_name" in parameters: |
162 | 218 | kwargs["account_name"] = parameters.pop("account_name") |
163 | 219 | elif isinstance(auth, ClientCredentials): |
164 | | - # account_name is required for client credentials authentication |
165 | 220 | raise ArgumentError( |
166 | 221 | "account_name parameter must be provided to authenticate" |
167 | 222 | ) |
168 | | - self._set_parameters = parameters |
169 | | - # If URL override is not provided leave it to the sdk to determine the endpoint |
| 223 | + |
| 224 | + def _handle_environment_config( |
| 225 | + self, kwargs: Dict[str, Union[str, Auth, Dict[str, Any], None]] |
| 226 | + ) -> None: |
| 227 | + """Handle environment-based configuration.""" |
170 | 228 | if "FIREBOLT_BASE_URL" in os.environ: |
171 | 229 | kwargs["api_endpoint"] = os.environ["FIREBOLT_BASE_URL"] |
172 | | - # Tracking information |
| 230 | + |
| 231 | + def _build_additional_parameters( |
| 232 | + self, parameters: Dict[str, str] |
| 233 | + ) -> Dict[str, Any]: |
| 234 | + """Build additional parameters including tracking information.""" |
| 235 | + additional_parameters: Dict[str, Any] = {} |
| 236 | + |
173 | 237 | if "user_clients" in parameters or "user_drivers" in parameters: |
174 | 238 | additional_parameters["user_drivers"] = parameters.pop("user_drivers", []) |
175 | 239 | additional_parameters["user_clients"] = parameters.pop("user_clients", []) |
176 | | - kwargs["additional_parameters"] = additional_parameters |
177 | | - return ([], kwargs) |
| 240 | + |
| 241 | + return additional_parameters |
178 | 242 |
|
179 | 243 | def get_schema_names( |
180 | 244 | self, connection: AlchemyConnection, **kwargs: Any |
@@ -366,8 +430,13 @@ def get_is_nullable(column_is_nullable: int) -> bool: |
366 | 430 | return column_is_nullable == 1 |
367 | 431 |
|
368 | 432 |
|
369 | | -def _determine_auth(key: str, secret: str, token_cache_flag: bool = True) -> Auth: |
370 | | - if "@" in key: |
371 | | - return UsernamePassword(key, secret, token_cache_flag) |
| 433 | +def _determine_auth(url: URL, token_cache_flag: bool = True) -> Auth: |
| 434 | + parameters = dict(url.query) |
| 435 | + is_core_connection = "url" in parameters |
| 436 | + |
| 437 | + if is_core_connection: |
| 438 | + return FireboltCore() |
| 439 | + elif "@" in (url.username or ""): |
| 440 | + return UsernamePassword(url.username, url.password, token_cache_flag) |
372 | 441 | else: |
373 | | - return ClientCredentials(key, secret, token_cache_flag) |
| 442 | + return ClientCredentials(url.username, url.password, token_cache_flag) |
0 commit comments