Commit b29e71b

Feat!: Skip model evaluation if upstream external model(s) have not changed (#5277)
1 parent 0bf202c commit b29e71b

13 files changed: +346 / -8 lines changed

sqlmesh/core/context.py

Lines changed: 7 additions & 0 deletions
@@ -274,6 +274,7 @@ def __init__(
         deployability_index: t.Optional[DeployabilityIndex] = None,
         default_dialect: t.Optional[str] = None,
         default_catalog: t.Optional[str] = None,
+        is_restatement: t.Optional[bool] = None,
         variables: t.Optional[t.Dict[str, t.Any]] = None,
         blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
     ):
@@ -284,6 +285,7 @@ def __init__(
         self._default_dialect = default_dialect
         self._variables = variables or {}
         self._blueprint_variables = blueprint_variables or {}
+        self._is_restatement = is_restatement

     @property
     def default_dialect(self) -> t.Optional[str]:
@@ -308,6 +310,10 @@ def gateway(self) -> t.Optional[str]:
         """Returns the gateway name."""
         return self.var(c.GATEWAY)

+    @property
+    def is_restatement(self) -> t.Optional[bool]:
+        return self._is_restatement
+
     def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
         """Returns a variable value."""
         return self._variables.get(var_name.lower(), default)
@@ -328,6 +334,7 @@ def with_variables(
             self.deployability_index,
             self._default_dialect,
             self._default_catalog,
+            self._is_restatement,
             variables=variables,
             blueprint_variables=blueprint_variables,
         )
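
Note: the new flag is surfaced to signal authors as the read-only `ExecutionContext.is_restatement` property. A minimal sketch of a custom signal that consults it; the signal name is made up, and it assumes `signal` is importable from the top-level `sqlmesh` package and that a declared `context` parameter receives the `ExecutionContext` (as the built-in `freshness` signal added later in this commit does):

    from sqlmesh import signal  # assumption: top-level export of the decorator
    from sqlmesh.utils.date import DatetimeRanges


    @signal()
    def my_readiness_check(batch: DatetimeRanges, context) -> bool:  # hypothetical signal
        if context.is_restatement:
            # Restatements always re-evaluate, so never skip anything here.
            return True
        # Custom readiness logic would go here; True keeps the batch.
        return True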

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 0 deletions
@@ -119,6 +119,7 @@ class EngineAdapter:
     MAX_IDENTIFIER_LENGTH: t.Optional[int] = None
     ATTACH_CORRELATION_ID = True
     SUPPORTS_QUERY_EXECUTION_TRACKING = False
+    SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = False

     def __init__(
         self,
@@ -2927,6 +2928,9 @@ def _check_identifier_length(self, expression: exp.Expression) -> None:
                 f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters"
             )

+    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
+        raise NotImplementedError()
+

 class EngineAdapterWithIndexSupport(EngineAdapter):
     SUPPORTS_INDEXES = True
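
The base class pairs an opt-in capability flag with a default implementation that raises, so only adapters that advertise support are ever asked for table metadata. A hypothetical adapter for some other engine would opt in roughly like this (sketch only, not part of the commit):

    from sqlmesh.core.engine_adapter.base import EngineAdapter

    class MyWarehouseEngineAdapter(EngineAdapter):  # hypothetical engine adapter
        SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True

        def get_table_last_modified_ts(self, table_names):
            # Expected to return one epoch-millisecond timestamp per requested table,
            # typically by querying the engine's information schema.
            raise NotImplementedError("sketch only")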

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 22 additions & 0 deletions
@@ -755,6 +755,28 @@ def table_exists(self, table_name: TableName) -> bool:
         except NotFound:
             return False

+    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
+        from sqlmesh.utils.date import to_timestamp
+
+        datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
+        for table_name in table_names:
+            table = exp.to_table(table_name)
+            datasets_to_tables[table.db].append(table.name)
+
+        results = []
+
+        for dataset, tables in datasets_to_tables.items():
+            query = (
+                f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE "
+            )
+            for i, table_name in enumerate(tables):
+                query += f"TABLE_ID = '{table_name}'"
+                if i < len(tables) - 1:
+                    query += " OR "
+            results.extend(self.fetchall(query))
+
+        return [to_timestamp(row[0]) for row in results]
+
     def _get_table(self, table_name: TableName) -> BigQueryTable:
         """
         Returns a BigQueryTable object for the given table name.
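
The BigQuery implementation batches all tables of a dataset into a single metadata query against the dataset's `__TABLES__` view. A standalone sketch of the query construction with made-up table names (nothing is executed, it only prints the generated SQL):

    from collections import defaultdict

    table_names = ["analytics.raw_orders", "analytics.raw_customers"]  # made-up examples

    datasets_to_tables = defaultdict(list)
    for fqn in table_names:
        dataset, table = fqn.split(".", 1)
        datasets_to_tables[dataset].append(table)

    for dataset, tables in datasets_to_tables.items():
        predicate = " OR ".join(f"TABLE_ID = '{t}'" for t in tables)
        print(f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE {predicate}")
    # -> SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `analytics.__TABLES__`
    #    WHERE TABLE_ID = 'raw_orders' OR TABLE_ID = 'raw_customers'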

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 16 additions & 0 deletions
@@ -54,6 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixin
     SUPPORTS_MANAGED_MODELS = True
     CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
     SUPPORTS_CREATE_DROP_CATALOG = True
+    SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
     SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
     SCHEMA_DIFFER_KWARGS = {
         "parameterized_type_defaults": {
@@ -669,3 +670,18 @@ def close(self) -> t.Any:
             self._connection_pool.set_attribute(self.SNOWPARK, None)

         return super().close()
+
+    def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
+        from sqlmesh.utils.date import to_timestamp
+
+        num_tables = len(table_names)
+
+        query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE"
+        for i, table_name in enumerate(table_names):
+            table = exp.to_table(table_name)
+            query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"""
+            if i < num_tables - 1:
+                query += " OR "
+
+        result = self.fetchall(query)
+        return [to_timestamp(row[0]) for row in result]
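
Snowflake takes the same approach but issues a single `INFORMATION_SCHEMA.TABLES` query keyed on catalog, schema, and table name. A rough sketch of the predicate it assembles, using made-up fully qualified names (illustration only, not the exact string produced above):

    tables = [("analytics_db", "raw", "orders"), ("analytics_db", "raw", "customers")]  # made-up

    predicate = " OR ".join(
        f"(TABLE_NAME = '{name}' AND TABLE_SCHEMA = '{schema}' AND TABLE_CATALOG = '{catalog}')"
        for catalog, schema, name in tables
    )
    print(f"SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE {predicate}")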

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePlan
             allow_additive_snapshots=plan.allow_additive_models,
             selected_snapshot_ids=stage.selected_snapshot_ids,
             selected_models=plan.selected_models,
+            is_restatement=bool(plan.restatements),
         )
         if errors:
             raise PlanError("Plan application failed.")

sqlmesh/core/scheduler.py

Lines changed: 12 additions & 3 deletions
@@ -251,7 +251,9 @@ def evaluate(
             **kwargs,
         )

-        self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
+        self.state_sync.add_interval(
+            snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp()
+        )
         return audit_results

     def run(
@@ -335,6 +337,7 @@ def batch_intervals(
         deployability_index: t.Optional[DeployabilityIndex],
         environment_naming_info: EnvironmentNamingInfo,
         dag: t.Optional[DAG[SnapshotId]] = None,
+        is_restatement: bool = False,
     ) -> t.Dict[Snapshot, Intervals]:
         dag = dag or snapshots_to_dag(merged_intervals)

@@ -367,6 +370,7 @@ def batch_intervals(
             deployability_index,
             default_dialect=adapter.dialect,
             default_catalog=self.default_catalog,
+            is_restatement=is_restatement,
         )

         intervals = self._check_ready_intervals(
@@ -422,6 +426,7 @@ def run_merged_intervals(
         run_environment_statements: bool = False,
         audit_only: bool = False,
         auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
+        is_restatement: bool = False,
     ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
         """Runs precomputed batches of missing intervals.

@@ -455,9 +460,12 @@ def run_merged_intervals(
         snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)

         batched_intervals = self.batch_intervals(
-            merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
+            merged_intervals,
+            deployability_index,
+            environment_naming_info,
+            dag=snapshot_dag,
+            is_restatement=is_restatement,
         )
-
         self.console.start_evaluation_progress(
             batched_intervals,
             environment_naming_info,
@@ -956,6 +964,7 @@ def _check_ready_intervals(
                 python_env=signals.python_env,
                 dialect=snapshot.model.dialect,
                 path=snapshot.model._path,
+                snapshot=snapshot,
                 kwargs=kwargs,
             )
         except SQLMeshError as e:

sqlmesh/core/signal.py

Lines changed: 43 additions & 1 deletion
@@ -1,8 +1,14 @@
 from __future__ import annotations

-
+import typing as t
 from sqlmesh.utils import UniqueKeyDict, registry_decorator

+if t.TYPE_CHECKING:
+    from sqlmesh.core.context import ExecutionContext
+    from sqlmesh.core.snapshot.definition import Snapshot
+    from sqlmesh.utils.date import DatetimeRanges
+    from sqlmesh.core.snapshot.definition import DeployabilityIndex
+

 class signal(registry_decorator):
     """Specifies a function which intervals are ready from a list of scheduled intervals.
@@ -33,3 +39,39 @@ class signal(registry_decorator):


 SignalRegistry = UniqueKeyDict[str, signal]
+
+
+@signal()
+def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
+    adapter = context.engine_adapter
+    if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
+        return True
+
+    deployability_index = context.deployability_index or DeployabilityIndex.all_deployable()
+
+    last_altered_ts = (
+        snapshot.last_altered_ts
+        if deployability_index.is_deployable(snapshot)
+        else snapshot.dev_last_altered_ts
+    )
+    if not last_altered_ts:
+        return True
+
+    parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
+    if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
+        p.is_external for p in parent_snapshots
+    ):
+        # The mismatch can happen if e.g. an external model is not registered in the project
+        return True
+
+    # Finding new data means that the upstream dependencies have been altered
+    # since the last time the model was evaluated
+    upstream_dep_has_new_data = any(
+        upstream_last_altered_ts > last_altered_ts
+        for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
+            [p.name for p in parent_snapshots]
+        )
+    )
+
+    # Returning True is a no-op; returning False nullifies the batch so the model will not be evaluated.
+    return upstream_dep_has_new_data
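
The decision the signal makes reduces to a single comparison: if no upstream external table has been modified since the model's physical table was last written, the batch is nullified and evaluation is skipped. A self-contained illustration with made-up epoch-millisecond timestamps:

    snapshot_last_altered_ts = 1_700_000_000_000                      # when the model's table was last built
    upstream_last_modified = [1_699_999_000_000, 1_699_998_000_000]   # upstream external tables

    # True keeps the batch (the model is evaluated); False nullifies it.
    upstream_dep_has_new_data = any(ts > snapshot_last_altered_ts for ts in upstream_last_modified)
    print(upstream_dep_has_new_data)  # -> False, so this evaluation would be skipped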

sqlmesh/core/snapshot/definition.py

Lines changed: 31 additions & 0 deletions
@@ -185,6 +185,8 @@ class SnapshotIntervals(PydanticModel):
     intervals: Intervals = []
     dev_intervals: Intervals = []
     pending_restatement_intervals: Intervals = []
+    last_altered_ts: t.Optional[int] = None
+    dev_last_altered_ts: t.Optional[int] = None

     @property
     def snapshot_id(self) -> t.Optional[SnapshotId]:
@@ -205,6 +207,12 @@ def add_dev_interval(self, start: int, end: int) -> None:
     def add_pending_restatement_interval(self, start: int, end: int) -> None:
         self._add_interval(start, end, "pending_restatement_intervals")

+    def update_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
+        self._update_last_altered_ts(last_altered_ts, "last_altered_ts")
+
+    def update_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
+        self._update_last_altered_ts(last_altered_ts, "dev_last_altered_ts")
+
     def remove_interval(self, start: int, end: int) -> None:
         self._remove_interval(start, end, "intervals")

@@ -224,6 +232,13 @@ def _add_interval(self, start: int, end: int, interval_attr: str) -> None:
         target_intervals = merge_intervals([*target_intervals, (start, end)])
         setattr(self, interval_attr, target_intervals)

+    def _update_last_altered_ts(
+        self, last_altered_ts: t.Optional[int], last_altered_attr: str
+    ) -> None:
+        if last_altered_ts:
+            existing_last_altered_ts = getattr(self, last_altered_attr)
+            setattr(self, last_altered_attr, max(existing_last_altered_ts or 0, last_altered_ts))
+
     def _remove_interval(self, start: int, end: int, interval_attr: str) -> None:
         target_intervals = getattr(self, interval_attr)
         target_intervals = remove_interval(target_intervals, start, end)
@@ -713,6 +728,10 @@ class Snapshot(PydanticModel, SnapshotInfoMixin):
     dev_table_suffix: str = "dev"
     table_naming_convention: TableNamingConvention = TableNamingConvention.default
     forward_only: bool = False
+    # Physical table last modified timestamp, not to be confused with the "updated_ts" field
+    # which is for the snapshot record itself
+    last_altered_ts: t.Optional[int] = None
+    dev_last_altered_ts: t.Optional[int] = None

     @field_validator("ttl")
     @classmethod
@@ -751,6 +770,7 @@ def hydrate_with_intervals_by_version(
             )
             for interval in snapshot_intervals:
                 snapshot.merge_intervals(interval)
+
             result.append(snapshot)

         return result
@@ -957,12 +977,20 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
             if not apply_effective_from or end <= effective_from_ts:
                 self.add_interval(start, end)

+        if other.last_altered_ts:
+            self.last_altered_ts = max(self.last_altered_ts or 0, other.last_altered_ts)
+
         if self.dev_version == other.dev_version:
             # Merge dev intervals if the dev versions match which would mean
             # that this and the other snapshot are pointing to the same dev table.
             for start, end in other.dev_intervals:
                 self.add_interval(start, end, is_dev=True)

+            if other.dev_last_altered_ts:
+                self.dev_last_altered_ts = max(
+                    self.dev_last_altered_ts or 0, other.dev_last_altered_ts
+                )
+
         self.pending_restatement_intervals = merge_intervals(
             [*self.pending_restatement_intervals, *other.pending_restatement_intervals]
         )
@@ -1081,6 +1109,7 @@ def check_ready_intervals(
                 python_env=signals.python_env,
                 dialect=self.model.dialect,
                 path=self.model._path,
+                snapshot=self,
                 kwargs=kwargs,
             )
         except SQLMeshError as e:
@@ -2421,6 +2450,7 @@ def check_ready_intervals(
     python_env: t.Dict[str, Executable],
     dialect: DialectType = None,
    path: t.Optional[Path] = None,
+    snapshot: t.Optional[Snapshot] = None,
     kwargs: t.Optional[t.Dict] = None,
 ) -> Intervals:
     checked_intervals: Intervals = []
@@ -2436,6 +2466,7 @@ def check_ready_intervals(
                 provided_args=(batch,),
                 provided_kwargs=(kwargs or {}),
                 context=context,
+                snapshot=snapshot,
             )
         except Exception as ex:
             raise SignalEvalError(format_evaluated_code_exception(ex, python_env))
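
Both the snapshot and its interval records keep only the newest timestamp when merging, so out-of-order interval writes cannot move `last_altered_ts` backwards. A standalone illustration of that max-merge rule (made-up values; this is not the real class, just the comparison it performs):

    def merge_last_altered(existing, incoming):
        # Mirrors _update_last_altered_ts / merge_intervals above: falsy incoming
        # values are ignored, otherwise the newer of the two timestamps wins.
        return max(existing or 0, incoming) if incoming else existing

    assert merge_last_altered(None, 1_700_000_000_000) == 1_700_000_000_000
    assert merge_last_altered(1_700_000_000_000, 1_600_000_000_000) == 1_700_000_000_000
    assert merge_last_altered(1_700_000_000_000, None) == 1_700_000_000_000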

sqlmesh/core/state_sync/base.py

Lines changed: 4 additions & 0 deletions
@@ -496,6 +496,7 @@ def add_interval(
         start: TimeLike,
         end: TimeLike,
         is_dev: bool = False,
+        last_altered_ts: t.Optional[int] = None,
     ) -> None:
         """Add an interval to a snapshot and sync it to the store.

@@ -504,6 +505,7 @@ def add_interval(
             start: The start of the interval to add.
             end: The end of the interval to add.
             is_dev: Indicates whether the given interval is being added while in development mode
+            last_altered_ts: The timestamp of the last modification of the physical table
         """
         start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
         if not snapshot.version:
@@ -516,6 +518,8 @@ def add_interval(
             dev_version=snapshot.dev_version,
             intervals=intervals if not is_dev else [],
             dev_intervals=intervals if is_dev else [],
+            last_altered_ts=last_altered_ts if not is_dev else None,
+            dev_last_altered_ts=last_altered_ts if is_dev else None,
         )
         self.add_snapshots_intervals([snapshot_intervals])
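
The single timestamp passed in is routed to either the production or the dev field depending on `is_dev`, mirroring how the intervals themselves are split. A small sketch of that routing (illustration only):

    def route_last_altered(last_altered_ts, is_dev):
        # Mirrors the SnapshotIntervals construction above.
        return {
            "last_altered_ts": last_altered_ts if not is_dev else None,
            "dev_last_altered_ts": last_altered_ts if is_dev else None,
        }

    assert route_last_altered(1_700_000_000_000, is_dev=False)["dev_last_altered_ts"] is None
    assert route_last_altered(1_700_000_000_000, is_dev=True)["last_altered_ts"] is None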

sqlmesh/core/state_sync/db/facade.py

Lines changed: 2 additions & 1 deletion
@@ -381,8 +381,9 @@ def add_interval(
         start: TimeLike,
         end: TimeLike,
         is_dev: bool = False,
+        last_altered_ts: t.Optional[int] = None,
     ) -> None:
-        super().add_interval(snapshot, start, end, is_dev)
+        super().add_interval(snapshot, start, end, is_dev, last_altered_ts)

     @transactional()
     def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
