Skip to content

Commit 431e2fa

Browse files
committed
Fix!: Suppor forward-only changes of model kinds under certain circumstances (#5028)
1 parent 314d6c7 commit 431e2fa

File tree

4 files changed

+98
-15
lines changed

4 files changed

+98
-15
lines changed

sqlmesh/core/plan/builder.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def _categorize_snapshots(
592592
# If the model kind changes mark as breaking
593593
if snapshot.is_model and snapshot.name in self._context_diff.modified_snapshots:
594594
_, old = self._context_diff.modified_snapshots[snapshot.name]
595-
if old.model.kind.name != snapshot.model.kind.name:
595+
if _is_breaking_kind_change(old, snapshot):
596596
category = SnapshotChangeCategory.BREAKING
597597

598598
snapshot.categorize_as(category)
@@ -756,8 +756,8 @@ def _is_forward_only_change(self, s_id: SnapshotId) -> bool:
756756
snapshot = self._context_diff.snapshots[s_id]
757757
if snapshot.name in self._context_diff.modified_snapshots:
758758
_, old = self._context_diff.modified_snapshots[snapshot.name]
759-
# If the model kind has changed, then we should not consider this to be a forward-only change.
760-
if snapshot.is_model and old.model.kind.name != snapshot.model.kind.name:
759+
# If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change.
760+
if snapshot.is_model and _is_breaking_kind_change(old, snapshot):
761761
return False
762762
return (
763763
snapshot.is_model
@@ -873,3 +873,16 @@ def _modified_and_added_snapshots(self) -> t.List[Snapshot]:
873873
if snapshot.name in self._context_diff.modified_snapshots
874874
or snapshot.snapshot_id in self._context_diff.added
875875
]
876+
877+
878+
def _is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
879+
if old.model.kind.name == new.model.kind.name:
880+
# If the kind hasn't changed, then it's not a breaking change
881+
return False
882+
if not old.is_incremental or not new.is_incremental:
883+
# If either is not incremental, then it's a breaking change
884+
return True
885+
if old.model.partitioned_by == new.model.partitioned_by:
886+
# If the partitioning hasn't changed, then it's not a breaking change
887+
return False
888+
return True

sqlmesh/dbt/model.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,13 @@ def model_kind(self, context: DbtContext) -> ModelKind:
257257
if field_val is not None:
258258
incremental_by_kind_kwargs[field] = field_val
259259

260+
disable_restatement = self.disable_restatement
261+
if disable_restatement is None:
262+
disable_restatement = (
263+
not self.full_refresh if self.full_refresh is not None else False
264+
)
265+
incremental_kind_kwargs["disable_restatement"] = disable_restatement
266+
260267
if self.time_column:
261268
strategy = self.incremental_strategy or target.default_incremental_strategy(
262269
IncrementalByTimeRangeKind
@@ -270,20 +277,11 @@ def model_kind(self, context: DbtContext) -> ModelKind:
270277

271278
return IncrementalByTimeRangeKind(
272279
time_column=self.time_column,
273-
disable_restatement=(
274-
self.disable_restatement if self.disable_restatement is not None else False
275-
),
276280
auto_restatement_intervals=self.auto_restatement_intervals,
277281
**incremental_kind_kwargs,
278282
**incremental_by_kind_kwargs,
279283
)
280284

281-
disable_restatement = self.disable_restatement
282-
if disable_restatement is None:
283-
disable_restatement = (
284-
not self.full_refresh if self.full_refresh is not None else False
285-
)
286-
287285
if self.unique_key:
288286
strategy = self.incremental_strategy or target.default_incremental_strategy(
289287
IncrementalByUniqueKeyKind
@@ -309,7 +307,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:
309307

310308
return IncrementalByUniqueKeyKind(
311309
unique_key=self.unique_key,
312-
disable_restatement=disable_restatement,
313310
**incremental_kind_kwargs,
314311
**incremental_by_kind_kwargs,
315312
)
@@ -319,7 +316,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:
319316
)
320317
return IncrementalUnmanagedKind(
321318
insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES,
322-
disable_restatement=disable_restatement,
323319
**incremental_kind_kwargs,
324320
)
325321
if materialization == Materialization.EPHEMERAL:

tests/core/test_plan.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
from tests.core.test_table_diff import create_test_console
99
import time_machine
1010
from pytest_mock.plugin import MockerFixture
11-
from sqlglot import parse_one
11+
from sqlglot import parse_one, exp
1212

13+
from sqlmesh.core import dialect as d
1314
from sqlmesh.core.context import Context
1415
from sqlmesh.core.context_diff import ContextDiff
1516
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentStatements
1617
from sqlmesh.core.model import (
1718
ExternalModel,
1819
FullKind,
1920
IncrementalByTimeRangeKind,
21+
IncrementalUnmanagedKind,
2022
SeedKind,
2123
SeedModel,
2224
SqlModel,
@@ -1723,6 +1725,60 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix
17231725
assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING
17241726

17251727

1728+
@pytest.mark.parametrize(
1729+
"partitioned_by, expected_change_category",
1730+
[
1731+
([], SnapshotChangeCategory.BREAKING),
1732+
([d.parse_one("ds")], SnapshotChangeCategory.FORWARD_ONLY),
1733+
],
1734+
)
1735+
def test_forward_only_models_model_kind_changed_to_incremental_by_time_range(
1736+
make_snapshot,
1737+
partitioned_by: t.List[exp.Expression],
1738+
expected_change_category: SnapshotChangeCategory,
1739+
):
1740+
snapshot = make_snapshot(
1741+
SqlModel(
1742+
name="a",
1743+
query=parse_one("select 1, ds"),
1744+
kind=IncrementalUnmanagedKind(),
1745+
partitioned_by=partitioned_by,
1746+
)
1747+
)
1748+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
1749+
updated_snapshot = make_snapshot(
1750+
SqlModel(
1751+
name="a",
1752+
query=parse_one("select 3, ds"),
1753+
kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True),
1754+
)
1755+
)
1756+
updated_snapshot.previous_versions = snapshot.all_versions
1757+
1758+
context_diff = ContextDiff(
1759+
environment="test_environment",
1760+
is_new_environment=True,
1761+
is_unfinalized_environment=False,
1762+
normalize_environment_name=True,
1763+
create_from="prod",
1764+
create_from_env_exists=True,
1765+
added=set(),
1766+
removed_snapshots={},
1767+
modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)},
1768+
snapshots={updated_snapshot.snapshot_id: updated_snapshot},
1769+
new_snapshots={updated_snapshot.snapshot_id: updated_snapshot},
1770+
previous_plan_id=None,
1771+
previously_promoted_snapshot_ids=set(),
1772+
previous_finalized_snapshots=None,
1773+
previous_gateway_managed_virtual_layer=False,
1774+
gateway_managed_virtual_layer=False,
1775+
environment_statements=[],
1776+
)
1777+
1778+
PlanBuilder(context_diff, is_dev=True).build()
1779+
assert updated_snapshot.change_category == expected_change_category
1780+
1781+
17261782
def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFixture):
17271783
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
17281784
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)

tests/dbt/test_transformation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,24 @@ def test_model_kind():
244244
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
245245
)
246246

247+
assert ModelConfig(
248+
materialized=Materialization.INCREMENTAL,
249+
time_column="foo",
250+
incremental_strategy="merge",
251+
full_refresh=True,
252+
).model_kind(context) == IncrementalByTimeRangeKind(
253+
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
254+
)
255+
256+
assert ModelConfig(
257+
materialized=Materialization.INCREMENTAL,
258+
time_column="foo",
259+
incremental_strategy="merge",
260+
full_refresh=False,
261+
).model_kind(context) == IncrementalByTimeRangeKind(
262+
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=True
263+
)
264+
247265
assert ModelConfig(
248266
materialized=Materialization.INCREMENTAL,
249267
time_column="foo",

0 commit comments

Comments
 (0)