diff --git a/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py b/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py
index 1a7bfa7..66955b0 100644
--- a/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py
+++ b/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py
@@ -11,7 +11,6 @@
 from sqlmesh.utils.date import TimeLike
 from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
 from sqlmesh import CustomKind
-from sqlmesh.utils import columns_to_types_all_known
 
 if t.TYPE_CHECKING:
     from sqlmesh.core.engine_adapter._typing import QueryOrDF
@@ -76,6 +75,7 @@ def insert(
         query_or_df: QueryOrDF,
         model: Model,
         is_first_insert: bool,
+        render_kwargs: t.Dict[str, t.Any],
         **kwargs: t.Any,
     ) -> None:
         # sanity check
@@ -88,9 +88,15 @@ def insert(
         start: TimeLike = kwargs["start"]
         end: TimeLike = kwargs["end"]
 
-        columns_to_types = model.columns_to_types
-        if not columns_to_types or not columns_to_types_all_known(columns_to_types):
-            columns_to_types = self.adapter.columns(table_name)
+        if is_first_insert and not self.adapter.table_exists(table_name):
+            self.adapter.ctas(
+                table_name=table_name,
+                query_or_df=model.ctas_query(**render_kwargs),
+            )
+
+        columns_to_types, source_columns = self._get_target_and_source_columns(
+            model, table_name, render_kwargs=render_kwargs
+        )
 
         low, high = [
             model.convert_to_time_column(dt, columns_to_types)
@@ -116,9 +122,10 @@ def _inject_alias(node: exp.Expression, alias: str) -> exp.Expression:
         self.adapter.merge(
             target_table=table_name,
             source_table=query_or_df,
-            columns_to_types=columns_to_types,
+            target_columns_to_types=columns_to_types,
             unique_key=model.kind.primary_key,
             merge_filter=exp.and_(*betweens),
+            source_columns=source_columns,
         )
 
     def append(
@@ -126,6 +133,7 @@ def append(
         table_name: str,
         query_or_df: QueryOrDF,
         model: Model,
+        render_kwargs: t.Dict[str, t.Any],
         **kwargs: t.Any,
     ) -> None:
         self.insert(
@@ -133,5 +141,6 @@ def append(
             query_or_df=query_or_df,
             model=model,
             is_first_insert=False,
+            render_kwargs=render_kwargs,
             **kwargs,
         )
diff --git a/tests/materializations/integration/test_integration_non_idempotent_incremental_by_time_range.py b/tests/materializations/integration/test_integration_non_idempotent_incremental_by_time_range.py
index 4e27fd8..7a73f61 100644
--- a/tests/materializations/integration/test_integration_non_idempotent_incremental_by_time_range.py
+++ b/tests/materializations/integration/test_integration_non_idempotent_incremental_by_time_range.py
@@ -30,7 +30,7 @@ def test_basic_usage(project: Project):
     }
 
     project.engine_adapter.create_table(
-        upstream_table_name, columns_to_types=upstream_table_columns
+        upstream_table_name, target_columns_to_types=upstream_table_columns
     )
     project.engine_adapter.insert_append(
         upstream_table_name,
@@ -122,7 +122,7 @@ def test_partial_restatement(project: Project):
     }
 
     project.engine_adapter.create_table(
-        upstream_table_name, columns_to_types=upstream_table_columns
+        upstream_table_name, target_columns_to_types=upstream_table_columns
     )
     project.engine_adapter.insert_append(
         upstream_table_name,
@@ -174,7 +174,7 @@ def test_partial_restatement(project: Project):
     # change upstream data
     project.engine_adapter.drop_table(upstream_table_name)
     project.engine_adapter.create_table(
-        upstream_table_name, columns_to_types=upstream_table_columns
+        upstream_table_name, target_columns_to_types=upstream_table_columns
    )
     project.engine_adapter.insert_append(
         upstream_table_name,
diff --git a/tests/materializations/test_non_idempotent_incremental_by_time_range.py b/tests/materializations/test_non_idempotent_incremental_by_time_range.py
index 9d37c38..3586524 100644
--- a/tests/materializations/test_non_idempotent_incremental_by_time_range.py
+++ b/tests/materializations/test_non_idempotent_incremental_by_time_range.py
@@ -92,9 +92,11 @@ def test_insert(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
         is_first_insert=True,
         start=start,
         end=end,
+        render_kwargs={},
     )
 
     assert to_sql_calls(adapter) == [
+        'DESCRIBE "test"."snapshot_table"',
         parse_one(
             """
             MERGE INTO "test"."snapshot_table" AS "__merge_target__"
@@ -115,7 +117,7 @@ def test_insert(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
             WHEN NOT MATCHED THEN INSERT ("name", "ds") VALUES ("__MERGE_SOURCE__"."name", "__MERGE_SOURCE__"."ds")
             """,
             dialect=adapter.dialect,
-        ).sql(dialect=adapter.dialect)
+        ).sql(dialect=adapter.dialect),
     ]
@@ -135,6 +137,7 @@ def test_append(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
         model=model,
         start=start,
         end=end,
+        render_kwargs={},
     )
 
     assert to_sql_calls(adapter) == [
@@ -158,7 +161,7 @@ def test_append(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
             WHEN NOT MATCHED THEN INSERT ("name", "ds") VALUES ("__MERGE_SOURCE__"."name", "__MERGE_SOURCE__"."ds")
             """,
             dialect=adapter.dialect,
-        ).sql(dialect=adapter.dialect)
+        ).sql(dialect=adapter.dialect),
     ]
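Taken together, the non-test hunks reshape the insert path: on the first insert the target table is created up front via a CTAS of the model query if it does not exist, target/source column mappings come from the new _get_target_and_source_columns helper, and the merge call uses the renamed target_columns_to_types argument plus the new source_columns argument. Below is a minimal sketch of that flow in plain Python, assuming duck-typed materialization/adapter/model objects that expose only the methods referenced in the diff; the helper is not reimplemented here and the time-range merge_filter is elided.

import typing as t


def sketch_insert_flow(
    materialization: t.Any,
    table_name: str,
    query_or_df: t.Any,
    model: t.Any,
    is_first_insert: bool,
    render_kwargs: t.Dict[str, t.Any],
) -> None:
    # Illustrative only: mirrors the control flow added by this diff, not the real class.
    adapter = materialization.adapter

    # First insert against a missing table: materialize it from the model's CTAS query.
    if is_first_insert and not adapter.table_exists(table_name):
        adapter.ctas(
            table_name=table_name,
            query_or_df=model.ctas_query(**render_kwargs),
        )

    # Column resolution now goes through the helper (which, judging by the unit test,
    # may issue a DESCRIBE against the target table) instead of falling back to
    # model.columns_to_types / adapter.columns().
    columns_to_types, source_columns = materialization._get_target_and_source_columns(
        model, table_name, render_kwargs=render_kwargs
    )

    # Merge with the renamed/extended keyword arguments.
    adapter.merge(
        target_table=table_name,
        source_table=query_or_df,
        target_columns_to_types=columns_to_types,
        unique_key=model.kind.primary_key,
        source_columns=source_columns,
    )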