
Commit e949176

feat: add catalog type overrides (#5382)
1 parent 2462298 commit e949176

7 files changed, +145 -20 lines changed


docs/integrations/engines/trino.md

Lines changed: 34 additions & 13 deletions
@@ -81,19 +81,21 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/
 
 ### Connection options
 
-| Option               | Description | Type | Required |
-|----------------------|-------------|:------:|:--------:|
-| `type`               | Engine type name - must be `trino` | string | Y |
-| `user`               | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
-| `host`               | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
-| `catalog`            | The name of a catalog in your cluster. | string | Y |
-| `http_scheme`        | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
-| `port`               | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
-| `roles`              | Mapping of catalog name to a role | dict | N |
-| `http_headers`       | Additional HTTP headers to send with each request. | dict | N |
-| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
-| `retries`            | Number of retries to attempt when a request fails. Default: `3` | int | N |
-| `timezone`           | Timezone to use for the connection. Default: client-side local timezone | string | N |
+| Option                    | Description | Type | Required |
+|---------------------------|-------------|:------:|:--------:|
+| `type`                    | Engine type name - must be `trino` | string | Y |
+| `user`                    | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
+| `host`                    | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
+| `catalog`                 | The name of a catalog in your cluster. | string | Y |
+| `http_scheme`             | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
+| `port`                    | The port to connect to your cluster. By default, it's `443` for the `https` scheme and `80` for `http`. | int | N |
+| `roles`                   | Mapping of catalog name to a role. | dict | N |
+| `http_headers`            | Additional HTTP headers to send with each request. | dict | N |
+| `session_properties`      | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
+| `retries`                 | Number of retries to attempt when a request fails. Default: `3` | int | N |
+| `timezone`                | Timezone to use for the connection. Default: client-side local timezone. | string | N |
+| `schema_location_mapping` | A mapping of regex patterns to S3 locations to use for the `LOCATION` property when creating schemas. See [Table and Schema locations](#table-and-schema-locations) for more details. | dict | N |
+| `catalog_type_overrides`  | A mapping of catalog names to their connector type. This is used to enable/disable connector-specific behavior. See [Catalog Type Overrides](#catalog-type-overrides) for more details. | dict | N |
 
 ## Table and Schema locations
 
@@ -204,6 +206,25 @@ SELECT ...
 
 This will cause SQLMesh to set the specified `LOCATION` when issuing a `CREATE TABLE` statement.
 
+## Catalog Type Overrides
+
+SQLMesh attempts to determine the connector type of a catalog by querying the `system.metadata.catalogs` table and checking the `connector_name` column.
+A connector name of `hive` triggers Hive connector behavior, while a name containing `iceberg` or `delta_lake` triggers Iceberg or Delta Lake connector behavior, respectively.
+However, the connector name may not always be a reliable way to determine the connector type, for example when using a custom connector or a fork of an existing connector.
+To handle such cases, you can use the `catalog_type_overrides` connection property to explicitly specify the connector type for specific catalogs.
+For example, to specify that the `datalake` catalog uses the Iceberg connector and the `analytics` catalog uses the Hive connector, configure the connection as follows:
+
+```yaml title="config.yaml"
+gateways:
+  trino:
+    connection:
+      type: trino
+      ...
+      catalog_type_overrides:
+        datalake: iceberg
+        analytics: hive
+```
+
 ## Authentication
 
 === "No Auth"

sqlmesh/core/config/connection.py

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
     pre_ping: bool
     pretty_sql: bool = False
     schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
+    catalog_type_overrides: t.Optional[t.Dict[str, str]] = None
 
     # Whether to share a single connection across threads or create a new connection per thread.
     shared_connection: t.ClassVar[bool] = False
@@ -176,6 +177,7 @@ def create_engine_adapter(
             pretty_sql=self.pretty_sql,
             shared_connection=self.shared_connection,
             schema_differ_overrides=self.schema_differ_overrides,
+            catalog_type_overrides=self.catalog_type_overrides,
             **self._extra_engine_config,
         )
 
sqlmesh/core/engine_adapter/base.py

Lines changed: 9 additions & 1 deletion
@@ -223,6 +223,10 @@ def schema_differ(self) -> SchemaDiffer:
             }
         )
 
+    @property
+    def _catalog_type_overrides(self) -> t.Dict[str, str]:
+        return self._extra_config.get("catalog_type_overrides") or {}
+
     @classmethod
     def _casted_columns(
         cls,
@@ -430,7 +434,11 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
             raise UnsupportedCatalogOperationError(
                 f"{self.dialect} does not support catalogs and a catalog was provided: {catalog}"
             )
-        return self.DEFAULT_CATALOG_TYPE
+        return (
+            self._catalog_type_overrides.get(catalog, self.DEFAULT_CATALOG_TYPE)
+            if catalog
+            else self.DEFAULT_CATALOG_TYPE
+        )
 
     def get_catalog_type_from_table(self, table: TableName) -> str:
         """Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type"""
sqlmesh/core/engine_adapter/trino.py

Lines changed: 3 additions & 1 deletion
@@ -71,7 +71,7 @@ class TrinoEngineAdapter(
     MAX_TIMESTAMP_PRECISION = 3
 
     @property
-    def schema_location_mapping(self) -> t.Optional[dict[re.Pattern, str]]:
+    def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
         return self._extra_config.get("schema_location_mapping")
 
     @property
@@ -86,6 +86,8 @@ def set_current_catalog(self, catalog: str) -> None:
     def get_catalog_type(self, catalog: t.Optional[str]) -> str:
         row: t.Tuple = tuple()
         if catalog:
+            if catalog_type_override := self._catalog_type_overrides.get(catalog):
+                return catalog_type_override
             row = (
                 self.fetchone(
                     f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"

tests/cli/test_cli.py

Lines changed: 6 additions & 5 deletions
@@ -957,6 +957,7 @@ def test_dlt_filesystem_pipeline(tmp_path):
         " # pre_ping: False\n"
         " # pretty_sql: False\n"
         " # schema_differ_overrides: \n"
+        " # catalog_type_overrides: \n"
         " # aws_access_key_id: \n"
         " # aws_secret_access_key: \n"
         " # role_arn: \n"
@@ -1960,11 +1961,11 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path):
 @time_machine.travel(FREEZE_TIME)
 def test_init_project_engine_configs(tmp_path):
     engine_type_to_config = {
-        "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ",
-        "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
-        "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
-        "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False",
-        "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ",
+        "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ",
+        "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ",
+        "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ",
+        "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False",
+        "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ",
     }
 
     for engine_type, expected_config in engine_type_to_config.items():

tests/core/engine_adapter/test_trino.py

Lines changed: 72 additions & 0 deletions
@@ -11,6 +11,7 @@
 from sqlmesh.core.model import load_sql_based_model
 from sqlmesh.core.model.definition import SqlModel
 from sqlmesh.core.dialect import schema_
+from sqlmesh.utils.date import to_ds
 from sqlmesh.utils.errors import SQLMeshError
 from tests.core.engine_adapter import to_sql_calls
 
@@ -683,3 +684,74 @@ def test_replace_table_catalog_support(
         sql_calls[0]
         == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
     )
+
+
+@pytest.mark.parametrize(
+    "catalog_type_overrides", [{}, {"my_catalog": "hive"}, {"other_catalog": "iceberg"}]
+)
+def test_insert_overwrite_time_partition_hive(
+    make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
+):
+    config = TrinoConnectionConfig(
+        user="user",
+        host="host",
+        catalog="catalog",
+        catalog_type_overrides=catalog_type_overrides,
+    )
+    adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
+        TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
+    )
+    adapter.fetchone = MagicMock(return_value=None)  # type: ignore
+
+    adapter.insert_overwrite_by_time_partition(
+        table_name=".".join(["my_catalog", "schema", "test_table"]),
+        query_or_df=parse_one("SELECT a, b FROM tbl"),
+        start="2022-01-01",
+        end="2022-01-02",
+        time_column="b",
+        time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
+        target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
+    )
+
+    assert to_sql_calls(adapter) == [
+        "SET SESSION my_catalog.insert_existing_partitions_behavior='OVERWRITE'",
+        'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
+        "SET SESSION my_catalog.insert_existing_partitions_behavior='APPEND'",
+    ]
+
+
+@pytest.mark.parametrize(
+    "catalog_type_overrides",
+    [
+        {"my_catalog": "iceberg"},
+        {"my_catalog": "unknown"},
+    ],
+)
+def test_insert_overwrite_time_partition_iceberg(
+    make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
+):
+    config = TrinoConnectionConfig(
+        user="user",
+        host="host",
+        catalog="catalog",
+        catalog_type_overrides=catalog_type_overrides,
+    )
+    adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
+        TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
+    )
+    adapter.fetchone = MagicMock(return_value=None)  # type: ignore
+
+    adapter.insert_overwrite_by_time_partition(
+        table_name=".".join(["my_catalog", "schema", "test_table"]),
+        query_or_df=parse_one("SELECT a, b FROM tbl"),
+        start="2022-01-01",
+        end="2022-01-02",
+        time_column="b",
+        time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
+        target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
+    )
+
+    assert to_sql_calls(adapter) == [
+        'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
+        'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
+    ]
