
Commit 26ebace

Feat: prevent other processes seeing missing intervals during restatement (#5285)
1 parent 85cfe5a commit 26ebace

7 files changed: +986 -66 lines changed

sqlmesh/core/console.py

Lines changed: 48 additions & 0 deletions
@@ -551,6 +551,22 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
     def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
         """Display list of models that failed during evaluation to the user."""

+    @abc.abstractmethod
+    def log_models_updated_during_restatement(
+        self,
+        snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
+        environment_naming_info: EnvironmentNamingInfo,
+        default_catalog: t.Optional[str],
+    ) -> None:
+        """Display a list of models where new versions got deployed to the specified environment while we were restating data for the old versions.
+
+        Args:
+            snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples
+            environment: which environment got updated while we were restating models
+            environment_naming_info: how snapshots are named in that environment (for display name purposes)
+            default_catalog: the configured default catalog (for display name purposes)
+        """
+
     @abc.abstractmethod
     def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
         """Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message."""
@@ -771,6 +787,14 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
     def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
         pass

+    def log_models_updated_during_restatement(
+        self,
+        snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
+        environment_naming_info: EnvironmentNamingInfo,
+        default_catalog: t.Optional[str],
+    ) -> None:
+        pass
+
     def log_destructive_change(
         self,
         snapshot_name: str,
@@ -2225,6 +2249,30 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
         for node_name, msg in error_messages.items():
             self._print(f" [red]{node_name}[/red]\n\n{msg}")

+    def log_models_updated_during_restatement(
+        self,
+        snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
+        environment_naming_info: EnvironmentNamingInfo,
+        default_catalog: t.Optional[str] = None,
+    ) -> None:
+        if snapshots:
+            tree = Tree(
+                f"[yellow]The following models had new versions deployed while data was being restated:[/yellow]"
+            )
+
+            for restated_snapshot, updated_snapshot in snapshots:
+                display_name = restated_snapshot.display_name(
+                    environment_naming_info,
+                    default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
+                    dialect=self.dialect,
+                )
+                current_branch = tree.add(display_name)
+                current_branch.add(f"restated version: '{restated_snapshot.version}'")
+                current_branch.add(f"currently active version: '{updated_snapshot.version}'")
+
+            self._print(tree)
+            self._print("")  # newline spacer
+
     def log_destructive_change(
         self,
         snapshot_name: str,
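
As a rough illustration (not part of the commit), the warning tree printed by the TerminalConsole implementation above would look roughly like the output of this self-contained sketch; the model name and version hashes are made-up placeholders:

```python
# Standalone sketch approximating the warning tree rendered by the new
# log_models_updated_during_restatement implementation; names/versions are placeholders.
from rich.console import Console
from rich.tree import Tree

tree = Tree(
    "[yellow]The following models had new versions deployed while data was being restated:[/yellow]"
)
branch = tree.add('"db"."orders"')  # hypothetical display name of a restated model
branch.add("restated version: 'abc123'")  # version this plan restated
branch.add("currently active version: 'def456'")  # version promoted while restating
Console().print(tree)
```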

sqlmesh/core/plan/evaluator.py

Lines changed: 69 additions & 25 deletions
@@ -22,7 +22,7 @@
 from sqlmesh.core.console import Console, get_console
 from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
 from sqlmesh.core.macros import RuntimeStage
-from sqlmesh.core.snapshot.definition import to_view_mapping
+from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo
 from sqlmesh.core.plan import stages
 from sqlmesh.core.plan.definition import EvaluatablePlan
 from sqlmesh.core.scheduler import Scheduler
@@ -40,7 +40,7 @@
 from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
 from sqlmesh.utils import CorrelationId
 from sqlmesh.utils.concurrency import NodeExecutionFailedError
-from sqlmesh.utils.errors import PlanError, SQLMeshError
+from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError
 from sqlmesh.utils.date import now, to_timestamp

 logger = logging.getLogger(__name__)
@@ -287,34 +287,78 @@ def visit_audit_only_run_stage(
     def visit_restatement_stage(
         self, stage: stages.RestatementStage, plan: EvaluatablePlan
     ) -> None:
-        snapshot_intervals_to_restate = {
-            (s.id_and_version, i) for s, i in stage.snapshot_intervals.items()
-        }
-
-        # Restating intervals on prod plans should mean that the intervals are cleared across
-        # all environments, not just the version currently in prod
-        # This ensures that work done in dev environments can still be promoted to prod
-        # by forcing dev environments to re-run intervals that changed in prod
+        # Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled
+        # (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments.
+        #
+        # This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to
+        # re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing)
+        #
+        # It also means that any new dev environments created while this restatement plan was running also get the
+        # correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan
+        # was created, which could have been several hours ago if there was a lot of data to restate.
         #
         # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
-        snapshot_intervals_to_restate.update(
-            {
-                (s.snapshot, s.interval)
-                for s in identify_restatement_intervals_across_snapshot_versions(
-                    state_reader=self.state_sync,
-                    prod_restatements=plan.restatements,
-                    disable_restatement_models=plan.disabled_restatement_models,
-                    loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
-                    current_ts=to_timestamp(plan.execution_time or now()),
-                ).values()
-            }
-        )

-        self.state_sync.remove_intervals(
-            snapshot_intervals=list(snapshot_intervals_to_restate),
-            remove_shared_versions=plan.is_prod,
+        intervals_to_clear = identify_restatement_intervals_across_snapshot_versions(
+            state_reader=self.state_sync,
+            prod_restatements=plan.restatements,
+            disable_restatement_models=plan.disabled_restatement_models,
+            loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
+            current_ts=to_timestamp(plan.execution_time or now()),
         )

+        if not intervals_to_clear:
+            # Nothing to do
+            return
+
+        # While the restatements were being processed, did any of the snapshots being restated get new versions deployed?
+        # If they did, they will not reflect the data that just got restated, so we need to notify the user
+        deployed_during_restatement: t.Dict[
+            str, t.Tuple[SnapshotTableInfo, SnapshotTableInfo]
+        ] = {}  # tuple of (restated_snapshot, current_prod_snapshot)
+
+        if deployed_env := self.state_sync.get_environment(plan.environment.name):
+            promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots}
+
+            for name in plan.restatements:
+                snapshot = stage.all_snapshots[name]
+                version = snapshot.table_info.version
+                if (
+                    prod_snapshot := promoted_snapshots_by_name.get(name)
+                ) and prod_snapshot.version != version:
+                    deployed_during_restatement[name] = (
+                        snapshot.table_info,
+                        prod_snapshot.table_info,
+                    )
+
+        # we need to *not* clear the intervals on the snapshots where new versions were deployed while the restatement was running in order to prevent
+        # subsequent plans from having unexpected intervals to backfill.
+        # we instead list the affected models and abort the plan with an error so the user can decide what to do
+        # (either re-attempt the restatement plan or leave things as they are)
+        filtered_intervals_to_clear = [
+            (s.snapshot, s.interval)
+            for s in intervals_to_clear.values()
+            if s.snapshot.name not in deployed_during_restatement
+        ]
+
+        if filtered_intervals_to_clear:
+            # We still clear intervals in other envs for models that were successfully restated without having new versions promoted during restatement
+            self.state_sync.remove_intervals(
+                snapshot_intervals=filtered_intervals_to_clear,
+                remove_shared_versions=plan.is_prod,
+            )
+
+        if deployed_env and deployed_during_restatement:
+            self.console.log_models_updated_during_restatement(
+                list(deployed_during_restatement.values()),
+                plan.environment.naming_info,
+                self.default_catalog,
+            )
+            raise ConflictingPlanError(
+                f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n"
+                "Please re-apply your plan if these new versions should be restated."
+            )
+
     def visit_environment_record_update_stage(
         self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan
     ) -> None:
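
The heart of the new evaluator logic is a version comparison between the snapshots this plan restated and the snapshots currently promoted in the target environment. Here is a minimal, self-contained sketch of that comparison, using a stand-in SnapshotRef type rather than SQLMesh's SnapshotTableInfo:

```python
# Illustrative sketch of the "did a new version get deployed mid-restatement?" check.
# SnapshotRef and the sample data are stand-ins, not SQLMesh types.
import typing as t
from dataclasses import dataclass


@dataclass(frozen=True)
class SnapshotRef:
    name: str
    version: str


def versions_deployed_during_restatement(
    restated: t.Dict[str, SnapshotRef],
    currently_promoted: t.Dict[str, SnapshotRef],
) -> t.Dict[str, t.Tuple[SnapshotRef, SnapshotRef]]:
    """Return (restated, currently_promoted) pairs whose versions have diverged."""
    conflicts = {}
    for name, restated_ref in restated.items():
        promoted_ref = currently_promoted.get(name)
        if promoted_ref and promoted_ref.version != restated_ref.version:
            conflicts[name] = (restated_ref, promoted_ref)
    return conflicts


restated = {"db.orders": SnapshotRef("db.orders", "v1")}
promoted = {"db.orders": SnapshotRef("db.orders", "v2")}  # new version deployed mid-restatement
print(versions_deployed_during_restatement(restated, promoted))
```

When this mapping is non-empty, the real implementation skips clearing intervals for those models and raises ConflictingPlanError so the user can decide whether to re-apply the plan.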

sqlmesh/core/plan/explainer.py

Lines changed: 69 additions & 10 deletions
@@ -1,22 +1,27 @@
+from __future__ import annotations
+
 import abc
 import typing as t
 import logging
+from dataclasses import dataclass

 from rich.console import Console as RichConsole
 from rich.tree import Tree
 from sqlglot.dialects.dialect import DialectType
 from sqlmesh.core import constants as c
 from sqlmesh.core.console import Console, TerminalConsole, get_console
 from sqlmesh.core.environment import EnvironmentNamingInfo
+from sqlmesh.core.plan.common import (
+    SnapshotIntervalClearRequest,
+    identify_restatement_intervals_across_snapshot_versions,
+)
 from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals
 from sqlmesh.core.plan import stages
 from sqlmesh.core.plan.evaluator import (
     PlanEvaluator,
 )
 from sqlmesh.core.state_sync import StateReader
-from sqlmesh.core.snapshot.definition import (
-    SnapshotInfoMixin,
-)
+from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion
 from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
 from sqlmesh.utils.date import to_ts
 from sqlmesh.utils.errors import SQLMeshError
@@ -45,6 +50,15 @@ def evaluate(
         explainer_console = _get_explainer_console(
             self.console, plan.environment, self.default_catalog
         )
+
+        # add extra metadata that's only needed at this point for better --explain output
+        plan_stages = [
+            ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan)
+            if isinstance(stage, stages.RestatementStage)
+            else stage
+            for stage in plan_stages
+        ]
+
         explainer_console.explain(plan_stages)

@@ -54,6 +68,38 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None:
         pass


+@dataclass
+class ExplainableRestatementStage(stages.RestatementStage):
+    """
+    This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication
+    of what might happen when they ask for the plan to be explained
+    """
+
+    snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
+    """Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
+
+    @classmethod
+    def from_restatement_stage(
+        cls: t.Type[ExplainableRestatementStage],
+        stage: stages.RestatementStage,
+        state_reader: StateReader,
+        plan: EvaluatablePlan,
+    ) -> ExplainableRestatementStage:
+        all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions(
+            state_reader=state_reader,
+            prod_restatements=plan.restatements,
+            disable_restatement_models=plan.disabled_restatement_models,
+            loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
+        )
+
+        return cls(
+            snapshot_intervals_to_clear={
+                s.snapshot.name: s for s in all_restatement_intervals.values()
+            },
+            all_snapshots=stage.all_snapshots,
+        )
+
+
 MAX_TREE_LENGTH = 10

@@ -146,11 +192,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree:
             tree.add(display_name)
         return tree

-    def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree:
+    def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree:
+        return self.visit_restatement_stage(stage)
+
+    def visit_restatement_stage(
+        self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
+    ) -> Tree:
         tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
-        for snapshot_table_info, interval in stage.snapshot_intervals.items():
-            display_name = self._display_name(snapshot_table_info)
-            tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
+
+        if isinstance(stage, ExplainableRestatementStage) and (
+            snapshot_intervals := stage.snapshot_intervals_to_clear
+        ):
+            for clear_request in snapshot_intervals.values():
+                display_name = self._display_name(clear_request.snapshot)
+                interval = clear_request.interval
+                tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
+
         return tree

     def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree:
@@ -265,12 +322,14 @@ def visit_finalize_environment_stage(

     def _display_name(
         self,
-        snapshot: SnapshotInfoMixin,
+        snapshot: t.Union[SnapshotInfoMixin, SnapshotIdAndVersion],
         environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
     ) -> str:
         return snapshot.display_name(
-            environment_naming_info or self.environment_naming_info,
-            self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
+            environment_naming_info=environment_naming_info or self.environment_naming_info,
+            default_catalog=self.default_catalog
+            if self.verbosity < Verbosity.VERY_VERBOSE
+            else None,
             dialect=self.dialect,
         )


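ExplainableRestatementStage above uses a simple pattern: subclass the stage dataclass and precompute display-only data via a from_* classmethod, so --explain can show what would be cleared without touching state. A standalone sketch of that pattern, with illustrative names rather than SQLMesh types:

```python
# Sketch of the "enriched stage for explain output" pattern; BaseStage/ExplainableStage
# are illustrative stand-ins, not the real plan stage classes.
from __future__ import annotations

import typing as t
from dataclasses import dataclass


@dataclass
class BaseStage:
    all_snapshots: t.Dict[str, str]  # model name -> version


@dataclass
class ExplainableStage(BaseStage):
    intervals_to_clear: t.Dict[str, t.Tuple[int, int]]  # model name -> (start_ts, end_ts)

    @classmethod
    def from_base(
        cls, stage: BaseStage, intervals: t.Dict[str, t.Tuple[int, int]]
    ) -> ExplainableStage:
        # carry over the base stage's data and attach the precomputed intervals
        return cls(all_snapshots=stage.all_snapshots, intervals_to_clear=intervals)


stage = BaseStage(all_snapshots={"db.orders": "v1"})
print(ExplainableStage.from_base(stage, {"db.orders": (0, 86_400_000)}))
```
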
sqlmesh/core/plan/stages.py

Lines changed: 22 additions & 16 deletions
@@ -12,7 +12,6 @@
     Snapshot,
     SnapshotTableInfo,
     SnapshotId,
-    Interval,
 )

@@ -98,14 +97,19 @@ class AuditOnlyRunStage:

 @dataclass
 class RestatementStage:
-    """Restate intervals for given snapshots.
+    """Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod.
+
+    This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed.
+    The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have
+    been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time.
+
+    Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments.

     Args:
-        snapshot_intervals: Intervals to restate.
-        all_snapshots: All snapshots in the plan by name.
+        all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their
+            intervals cleared; it's included here as an optimization to prevent having to re-fetch the current plan's snapshots.
     """

-    snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
     all_snapshots: t.Dict[str, Snapshot]

@@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
         if audit_only_snapshots:
             stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values())))

-        restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
-        if restatement_stage:
-            stages.append(restatement_stage)
-
         if missing_intervals_before_promote:
             stages.append(
                 BackfillStage(
@@ -349,6 +349,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
                 )
             )

+        # note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
+        # needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
+        # in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
+        # we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
+        # data for existing intervals and does not produce new ones
+        restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
+        if restatement_stage:
+            stages.append(restatement_stage)
+
         stages.append(
             EnvironmentRecordUpdateStage(
                 no_gaps_snapshot_names={s.name for s in before_promote_snapshots}
@@ -443,15 +452,12 @@ def _get_after_all_stage(
     def _get_restatement_stage(
         self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]
     ) -> t.Optional[RestatementStage]:
-        snapshot_intervals_to_restate = {}
-        for name, interval in plan.restatements.items():
-            restated_snapshot = snapshots_by_name[name]
-            restated_snapshot.remove_interval(interval)
-            snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
-        if not snapshot_intervals_to_restate or plan.is_dev:
+        if not plan.restatements or plan.is_dev:
+            # The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
             return None
+
         return RestatementStage(
-            snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name
+            all_snapshots=snapshots_by_name,
         )

     def _get_physical_layer_update_stage(
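
To summarize the ordering constraint described in the comment added to build() above (interval clearing must run after the backfill, and only on prod plans), here is a rough sketch; the stage names are illustrative strings, not the real PlanStage classes:

```python
# Rough sketch of the stage ordering; names are illustrative, not SQLMesh's stage classes.
import typing as t


def build_stage_names(is_dev: bool, has_restatements: bool, has_missing_intervals: bool) -> t.List[str]:
    stage_names = ["physical_layer_update"]
    if has_missing_intervals:
        stage_names.append("backfill")  # restated data is rewritten here
    if has_restatements and not is_dev:
        # clearing intervals only after the backfill means other plans/runs never
        # observe empty prod intervals while the restatement is still in flight
        stage_names.append("clear_intervals_in_other_environments")
    stage_names.append("environment_record_update")
    return stage_names


print(build_stage_names(is_dev=False, has_restatements=True, has_missing_intervals=True))
# ['physical_layer_update', 'backfill', 'clear_intervals_in_other_environments', 'environment_record_update']
```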

0 commit comments
