Skip to content

Commit 8dd5e38

Browse files
authored
Feat: Improve CLI and --explain output for restatements (#5348)
1 parent e949176 commit 8dd5e38

File tree

6 files changed

+77
-15
lines changed

6 files changed

+77
-15
lines changed

sqlmesh/core/console.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2022,7 +2022,34 @@ def _prompt_categorize(
20222022
plan = plan_builder.build()
20232023

20242024
if plan.restatements:
2025-
self._print("\n[bold]Restating models\n")
2025+
# A plan can have restatements for the following reasons:
2026+
# - The user specifically called `sqlmesh plan` with --restate-model.
2027+
# This creates a "restatement plan" which disallows all other changes and simply force-backfills
2028+
# the selected models and their downstream dependencies using the versions of the models stored in state.
2029+
# - There are no specific restatements (so changes are allowed) AND dev previews need to be computed.
2030+
# The "restatements" feature is currently reused for dev previews.
2031+
if plan.selected_models_to_restate:
2032+
# There were legitimate restatements, no dev previews
2033+
tree = Tree(
2034+
"[bold]Models selected for restatement:[/bold]\n"
2035+
"This causes backfill of the model itself as well as affected downstream models"
2036+
)
2037+
model_fqn_to_snapshot = {s.name: s for s in plan.snapshots.values()}
2038+
for model_fqn in plan.selected_models_to_restate:
2039+
snapshot = model_fqn_to_snapshot[model_fqn]
2040+
display_name = snapshot.display_name(
2041+
plan.environment_naming_info,
2042+
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
2043+
dialect=self.dialect,
2044+
)
2045+
tree.add(
2046+
display_name
2047+
) # note: we deliberately don't show any intervals here; they get shown in the backfill section
2048+
self._print(tree)
2049+
else:
2050+
# We are computing dev previews, do not confuse the user by printing out something to do
2051+
# with restatements. Dev previews are already highlighted in the backfill step
2052+
pass
20262053
else:
20272054
self.show_environment_difference_summary(
20282055
plan.context_diff,

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def build(self) -> Plan:
338338
directly_modified=directly_modified,
339339
indirectly_modified=indirectly_modified,
340340
deployability_index=deployability_index,
341+
selected_models_to_restate=self._restate_models,
341342
restatements=restatements,
342343
start_override_per_model=self._start_override_per_model,
343344
end_override_per_model=end_override_per_model,

sqlmesh/core/plan/definition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@ class Plan(PydanticModel, frozen=True):
5858
indirectly_modified: t.Dict[SnapshotId, t.Set[SnapshotId]]
5959

6060
deployability_index: DeployabilityIndex
61+
selected_models_to_restate: t.Optional[t.Set[str]] = None
62+
"""Models that have been explicitly selected for restatement by a user"""
6163
restatements: t.Dict[SnapshotId, Interval]
64+
"""
65+
All models being restated, which are typically the explicitly selected ones + their downstream dependencies.
66+
67+
Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty
68+
while :restatements is still populated with dev previews
69+
"""
70+
6271
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6372
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6473

sqlmesh/core/plan/explainer.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing as t
55
import logging
66
from dataclasses import dataclass
7+
from collections import defaultdict
78

89
from rich.console import Console as RichConsole
910
from rich.tree import Tree
@@ -21,7 +22,11 @@
2122
PlanEvaluator,
2223
)
2324
from sqlmesh.core.state_sync import StateReader
24-
from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion
25+
from sqlmesh.core.snapshot.definition import (
26+
SnapshotInfoMixin,
27+
SnapshotIdAndVersion,
28+
model_display_name,
29+
)
2530
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2631
from sqlmesh.utils.date import to_ts
2732
from sqlmesh.utils.errors import SQLMeshError
@@ -75,8 +80,8 @@ class ExplainableRestatementStage(stages.RestatementStage):
7580
of what might happen when they ask for the plan to be explained
7681
"""
7782

78-
snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
79-
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
83+
snapshot_intervals_to_clear: t.Dict[str, t.List[SnapshotIntervalClearRequest]]
84+
"""Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name."""
8085

8186
@classmethod
8287
def from_restatement_stage(
@@ -92,10 +97,13 @@ def from_restatement_stage(
9297
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
9398
)
9499

100+
# Group the interval clear requests by snapshot name to make them easier to write to the console
101+
snapshot_intervals_to_clear = defaultdict(list)
102+
for clear_request in all_restatement_intervals.values():
103+
snapshot_intervals_to_clear[clear_request.snapshot.name].append(clear_request)
104+
95105
return cls(
96-
snapshot_intervals_to_clear={
97-
s.snapshot.name: s for s in all_restatement_intervals.values()
98-
},
106+
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
99107
all_snapshots=stage.all_snapshots,
100108
)
101109

@@ -198,15 +206,30 @@ def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage
198206
def visit_restatement_stage(
199207
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
200208
) -> Tree:
201-
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
209+
tree = Tree(
210+
"[bold]Invalidate data intervals in state for development environments to prevent old data from being promoted[/bold]\n"
211+
"This only affects state and will not clear physical data from the tables until the next plan for each environment"
212+
)
202213

203214
if isinstance(stage, ExplainableRestatementStage) and (
204215
snapshot_intervals := stage.snapshot_intervals_to_clear
205216
):
206-
for clear_request in snapshot_intervals.values():
207-
display_name = self._display_name(clear_request.snapshot)
208-
interval = clear_request.interval
209-
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
217+
for name, clear_requests in snapshot_intervals.items():
218+
display_name = model_display_name(
219+
name, self.environment_naming_info, self.default_catalog, self.dialect
220+
)
221+
interval_start = min(cr.interval[0] for cr in clear_requests)
222+
interval_end = max(cr.interval[1] for cr in clear_requests)
223+
224+
if not interval_start or not interval_end:
225+
continue
226+
227+
node = tree.add(f"{display_name} [{to_ts(interval_start)} - {to_ts(interval_end)}]")
228+
229+
all_environment_names = sorted(
230+
set(env_name for cr in clear_requests for env_name in cr.environment_names)
231+
)
232+
node.add("in environments: " + ", ".join(all_environment_names))
210233

211234
return tree
212235

tests/cli/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_plan_restate_model(runner, tmp_path):
247247
)
248248
assert result.exit_code == 0
249249
assert_duckdb_test(result)
250-
assert "Restating models" in result.output
250+
assert "Models selected for restatement" in result.output
251251
assert "sqlmesh_example.full_model [full refresh" in result.output
252252
assert_model_batches_executed(result)
253253
assert "Virtual layer updated" not in result.output

tests/core/test_plan_stages.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,9 +771,11 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
771771
# note: we only clear the intervals from state for "a" in dev, we leave prod alone
772772
assert restatement_stage.snapshot_intervals_to_clear
773773
assert len(restatement_stage.snapshot_intervals_to_clear) == 1
774-
snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
775-
assert isinstance(clear_request, SnapshotIntervalClearRequest)
774+
snapshot_name, clear_requests = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
776775
assert snapshot_name == '"a"'
776+
assert len(clear_requests) == 1
777+
clear_request = clear_requests[0]
778+
assert isinstance(clear_request, SnapshotIntervalClearRequest)
777779
assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id
778780
assert clear_request.snapshot == snapshot_a_dev.id_and_version
779781
assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))

0 commit comments

Comments
 (0)