Skip to content

Commit 46bb346

Browse files
themisvaltinosizeigerman
authored andcommitted
Fix: Don't use SCD type 2 restatement logic in regular runs (#4976)
1 parent 2973dcf commit 46bb346

File tree

8 files changed

+210
-7
lines changed

8 files changed

+210
-7
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,7 @@ def scd_type_2_by_time(
14541454
table_description: t.Optional[str] = None,
14551455
column_descriptions: t.Optional[t.Dict[str, str]] = None,
14561456
truncate: bool = False,
1457+
is_restatement: bool = False,
14571458
**kwargs: t.Any,
14581459
) -> None:
14591460
self._scd_type_2(
@@ -1470,6 +1471,7 @@ def scd_type_2_by_time(
14701471
table_description=table_description,
14711472
column_descriptions=column_descriptions,
14721473
truncate=truncate,
1474+
is_restatement=is_restatement,
14731475
**kwargs,
14741476
)
14751477

@@ -1488,6 +1490,7 @@ def scd_type_2_by_column(
14881490
table_description: t.Optional[str] = None,
14891491
column_descriptions: t.Optional[t.Dict[str, str]] = None,
14901492
truncate: bool = False,
1493+
is_restatement: bool = False,
14911494
**kwargs: t.Any,
14921495
) -> None:
14931496
self._scd_type_2(
@@ -1504,6 +1507,7 @@ def scd_type_2_by_column(
15041507
table_description=table_description,
15051508
column_descriptions=column_descriptions,
15061509
truncate=truncate,
1510+
is_restatement=is_restatement,
15071511
**kwargs,
15081512
)
15091513

@@ -1525,6 +1529,7 @@ def _scd_type_2(
15251529
table_description: t.Optional[str] = None,
15261530
column_descriptions: t.Optional[t.Dict[str, str]] = None,
15271531
truncate: bool = False,
1532+
is_restatement: bool = False,
15281533
**kwargs: t.Any,
15291534
) -> None:
15301535
def remove_managed_columns(
@@ -1710,13 +1715,15 @@ def remove_managed_columns(
17101715
target_table
17111716
)
17121717

1713-
cleanup_ts = None
17141718
if truncate:
17151719
existing_rows_query = existing_rows_query.limit(0)
1716-
else:
1717-
# If truncate is false it is not the first insert
1718-
# Determine the cleanup timestamp for restatement or a regular incremental run
1719-
cleanup_ts = to_time_column(start, time_data_type, self.dialect, nullable=True)
1720+
1721+
# Only set cleanup_ts if is_restatement is True and truncate is False (this to enable full restatement)
1722+
cleanup_ts = (
1723+
to_time_column(start, time_data_type, self.dialect, nullable=True)
1724+
if is_restatement and not truncate
1725+
else None
1726+
)
17201727

17211728
with source_queries[0] as source_query:
17221729
prefixed_columns_to_types = []
@@ -1755,7 +1762,7 @@ def remove_managed_columns(
17551762
.with_(
17561763
"static",
17571764
existing_rows_query.where(valid_to_col.is_(exp.Null()).not_())
1758-
if truncate
1765+
if cleanup_ts is None
17591766
else existing_rows_query.where(
17601767
exp.and_(
17611768
valid_to_col.is_(exp.Null().not_()),
@@ -1767,7 +1774,7 @@ def remove_managed_columns(
17671774
.with_(
17681775
"latest",
17691776
existing_rows_query.where(valid_to_col.is_(exp.Null()))
1770-
if truncate
1777+
if cleanup_ts is None
17711778
else exp.select(
17721779
*(
17731780
to_time_column(

sqlmesh/core/engine_adapter/trino.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def _scd_type_2(
267267
table_description: t.Optional[str] = None,
268268
column_descriptions: t.Optional[t.Dict[str, str]] = None,
269269
truncate: bool = False,
270+
is_restatement: bool = False,
270271
**kwargs: t.Any,
271272
) -> None:
272273
if columns_to_types and self.current_catalog_type == "delta_lake":
@@ -289,6 +290,7 @@ def _scd_type_2(
289290
table_description,
290291
column_descriptions,
291292
truncate,
293+
is_restatement,
292294
**kwargs,
293295
)
294296

sqlmesh/core/plan/evaluator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,11 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
226226
return
227227

228228
scheduler = self.create_scheduler(stage.all_snapshots.values())
229+
# Convert model name restatements to snapshot ID restatements
230+
restatements_by_snapshot_id = {
231+
stage.all_snapshots[name].snapshot_id: interval
232+
for name, interval in plan.restatements.items()
233+
}
229234
errors, _ = scheduler.run_merged_intervals(
230235
merged_intervals=stage.snapshot_to_intervals,
231236
deployability_index=stage.deployability_index,
@@ -234,6 +239,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234239
circuit_breaker=self._circuit_breaker,
235240
start=plan.start,
236241
end=plan.end,
242+
restatements=restatements_by_snapshot_id,
237243
)
238244
if errors:
239245
raise PlanError("Plan application failed.")

sqlmesh/core/scheduler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def evaluate(
146146
deployability_index: DeployabilityIndex,
147147
batch_index: int,
148148
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
149+
is_restatement: bool = False,
149150
**kwargs: t.Any,
150151
) -> t.List[AuditResult]:
151152
"""Evaluate a snapshot and add the processed interval to the state sync.
@@ -177,6 +178,7 @@ def evaluate(
177178
snapshots=snapshots,
178179
deployability_index=deployability_index,
179180
batch_index=batch_index,
181+
is_restatement=is_restatement,
180182
**kwargs,
181183
)
182184
audit_results = self._audit_snapshot(
@@ -342,6 +344,7 @@ def run_merged_intervals(
342344
end: t.Optional[TimeLike] = None,
343345
run_environment_statements: bool = False,
344346
audit_only: bool = False,
347+
restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
345348
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
346349
"""Runs precomputed batches of missing intervals.
347350
@@ -416,6 +419,10 @@ def evaluate_node(node: SchedulingUnit) -> None:
416419
execution_time=execution_time,
417420
)
418421
else:
422+
# Determine if this snapshot and interval is a restatement (for SCD type 2)
423+
is_restatement = (
424+
restatements is not None and snapshot.snapshot_id in restatements
425+
)
419426
audit_results = self.evaluate(
420427
snapshot=snapshot,
421428
environment_naming_info=environment_naming_info,
@@ -424,6 +431,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
424431
execution_time=execution_time,
425432
deployability_index=deployability_index,
426433
batch_index=batch_idx,
434+
is_restatement=is_restatement,
427435
)
428436

429437
evaluation_duration_ms = now_timestamp() - execution_start_ts
@@ -629,6 +637,7 @@ def _run_or_audit(
629637
end=end,
630638
run_environment_statements=run_environment_statements,
631639
audit_only=audit_only,
640+
restatements=remove_intervals,
632641
)
633642

634643
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

sqlmesh/core/snapshot/evaluator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def evaluate(
139139
snapshots: t.Dict[str, Snapshot],
140140
deployability_index: t.Optional[DeployabilityIndex] = None,
141141
batch_index: int = 0,
142+
is_restatement: bool = False,
142143
**kwargs: t.Any,
143144
) -> t.Optional[str]:
144145
"""Renders the snapshot's model, executes it and stores the result in the snapshot's physical table.
@@ -164,6 +165,7 @@ def evaluate(
164165
snapshots,
165166
deployability_index=deployability_index,
166167
batch_index=batch_index,
168+
is_restatement=is_restatement,
167169
**kwargs,
168170
)
169171
if result is None or isinstance(result, str):
@@ -613,6 +615,7 @@ def _evaluate_snapshot(
613615
limit: t.Optional[int] = None,
614616
deployability_index: t.Optional[DeployabilityIndex] = None,
615617
batch_index: int = 0,
618+
is_restatement: bool = False,
616619
**kwargs: t.Any,
617620
) -> DF | str | None:
618621
"""Renders the snapshot's model and executes it. The return value depends on whether the limit was specified.
@@ -685,6 +688,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
685688
end=end,
686689
execution_time=execution_time,
687690
physical_properties=rendered_physical_properties,
691+
is_restatement=is_restatement,
688692
)
689693
else:
690694
logger.info(
@@ -706,6 +710,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
706710
end=end,
707711
execution_time=execution_time,
708712
physical_properties=rendered_physical_properties,
713+
is_restatement=is_restatement,
709714
)
710715

711716
with (
@@ -1789,6 +1794,7 @@ def insert(
17891794
column_descriptions=model.column_descriptions,
17901795
truncate=is_first_insert,
17911796
start=kwargs["start"],
1797+
is_restatement=kwargs.get("is_restatement", False),
17921798
)
17931799
elif isinstance(model.kind, SCDType2ByColumnKind):
17941800
self.adapter.scd_type_2_by_column(
@@ -1807,6 +1813,7 @@ def insert(
18071813
column_descriptions=model.column_descriptions,
18081814
truncate=is_first_insert,
18091815
start=kwargs["start"],
1816+
is_restatement=kwargs.get("is_restatement", False),
18101817
)
18111818
else:
18121819
raise SQLMeshError(

tests/core/engine_adapter/test_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
12231223
},
12241224
execution_time=datetime(2020, 1, 1, 0, 0, 0),
12251225
start=datetime(2020, 1, 1, 0, 0, 0),
1226+
is_restatement=True,
12261227
)
12271228

12281229
assert (
@@ -1422,6 +1423,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte
14221423
},
14231424
execution_time=datetime(2020, 1, 1, 0, 0, 0),
14241425
start=datetime(2020, 1, 1, 0, 0, 0),
1426+
is_restatement=True,
14251427
)
14261428

14271429
assert (
@@ -1610,6 +1612,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable):
16101612
},
16111613
execution_time=datetime(2020, 1, 1, 0, 0, 0),
16121614
start=datetime(2020, 1, 1, 0, 0, 0),
1615+
is_restatement=True,
16131616
)
16141617

16151618
assert (
@@ -1799,6 +1802,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable):
17991802
execution_time=datetime(2020, 1, 1, 0, 0, 0),
18001803
start=datetime(2020, 1, 1, 0, 0, 0),
18011804
extra_col_ignore="testing",
1805+
is_restatement=True,
18021806
)
18031807

18041808
assert (
@@ -1990,6 +1994,7 @@ def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callab
19901994
},
19911995
execution_time=datetime(2020, 1, 1, 0, 0, 0),
19921996
start=datetime(2020, 1, 1, 0, 0, 0),
1997+
is_restatement=True,
19931998
)
19941999
assert (
19952000
parse_one(adapter.cursor.execute.call_args[0][0]).sql()
@@ -2352,6 +2357,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable)
23522357
},
23532358
execution_time=datetime(2020, 1, 1, 0, 0, 0),
23542359
start=datetime(2020, 1, 1, 0, 0, 0),
2360+
is_restatement=True,
23552361
)
23562362

23572363
assert (
@@ -2527,6 +2533,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap
25272533
},
25282534
execution_time=datetime(2020, 1, 1, 0, 0, 0),
25292535
start=datetime(2020, 1, 1, 0, 0, 0),
2536+
is_restatement=True,
25302537
)
25312538

25322539
assert (

0 commit comments

Comments
 (0)