|
1 | 1 | from abc import abstractmethod |
2 | | -from typing import Literal, TypedDict |
| 2 | +from typing import Any, Literal, TypedDict |
| 3 | +from urllib.parse import parse_qs, urlparse |
3 | 4 |
|
4 | 5 | from testgen.common.encrypt import DecryptText |
5 | 6 |
|
@@ -37,19 +38,21 @@ class FlavorService: |
37 | 38 | private_key_passphrase = None |
38 | 39 | http_path = None |
39 | 40 | catalog = None |
| 41 | + warehouse = None |
40 | 42 |
|
41 | 43 | def init(self, connection_params: ConnectionParams): |
42 | | - self.url = connection_params.get("url", None) |
| 44 | + self.url = connection_params.get("url") or "" |
43 | 45 | self.connect_by_url = connection_params.get("connect_by_url", False) |
44 | | - self.username = connection_params.get("project_user") |
45 | | - self.host = connection_params.get("project_host") |
46 | | - self.port = connection_params.get("project_port") |
47 | | - self.dbname = connection_params.get("project_db") |
| 46 | + self.username = connection_params.get("project_user") or "" |
| 47 | + self.host = connection_params.get("project_host") or "" |
| 48 | + self.port = connection_params.get("project_port") or "" |
| 49 | + self.dbname = connection_params.get("project_db") or "" |
48 | 50 | self.flavor = connection_params.get("sql_flavor") |
49 | 51 | self.dbschema = connection_params.get("table_group_schema", None) |
50 | 52 | self.connect_by_key = connection_params.get("connect_by_key", False) |
51 | | - self.http_path = connection_params.get("http_path", None) |
52 | | - self.catalog = connection_params.get("catalog", None) |
| 53 | + self.http_path = connection_params.get("http_path") or "" |
| 54 | + self.catalog = connection_params.get("catalog") or "" |
| 55 | + self.warehouse = connection_params.get("warehouse") or "" |
53 | 56 |
|
54 | 57 | password = connection_params.get("project_pw_encrypted", None) |
55 | 58 | if isinstance(password, memoryview) or isinstance(password, bytes): |
@@ -90,3 +93,45 @@ def get_connection_string_from_fields(self) -> str: |
90 | 93 | @abstractmethod |
91 | 94 | def get_connection_string_head(self) -> str: |
92 | 95 | raise NotImplementedError("Subclasses must implement this method") |
| 96 | + |
| 97 | + def get_parts_from_connection_string(self) -> dict[str, Any]: |
| 98 | + if self.connect_by_url: |
| 99 | + if not self.url: |
| 100 | + return {} |
| 101 | + |
| 102 | + parsed_url = urlparse(self.get_connection_string()) |
| 103 | + credentials, location = ( |
| 104 | + parsed_url.netloc if "@" in parsed_url.netloc else f"@{parsed_url.netloc}" |
| 105 | + ).split("@") |
| 106 | + username, password = ( |
| 107 | + credentials if ":" in credentials else f"{credentials}:" |
| 108 | + ).split(":") |
| 109 | + host, port = ( |
| 110 | + location if ":" in location else f"{location}:" |
| 111 | + ).split(":") |
| 112 | + |
| 113 | + database = (path_patrs[0] if (path_patrs := parsed_url.path.strip("/").split("/")) else "") |
| 114 | + |
| 115 | + extras = { |
| 116 | + param_name: param_values[0] |
| 117 | + for param_name, param_values in parse_qs(parsed_url.query or "").items() |
| 118 | + } |
| 119 | + |
| 120 | + return { |
| 121 | + "username": username, |
| 122 | + "password": password, |
| 123 | + "host": host, |
| 124 | + "port": port, |
| 125 | + "dbname": database, |
| 126 | + **extras, |
| 127 | + } |
| 128 | + |
| 129 | + return { |
| 130 | + "username": self.username, |
| 131 | + "password": self.password, |
| 132 | + "host": self.host, |
| 133 | + "port": self.port, |
| 134 | + "dbname": self.dbname, |
| 135 | + "http_path": self.http_path, |
| 136 | + "catalog": self.catalog, |
| 137 | + } |
0 commit comments