11import logging
22import typing as t
33
4- from dagster import AssetExecutionContext , ConfigurableResource , MaterializeResult
4+ from dagster import (
5+ AssetExecutionContext ,
6+ ConfigurableResource ,
7+ MaterializeResult ,
8+ )
59from sqlmesh import Model
610from sqlmesh .core .context import Context as SQLMeshContext
7- from sqlmesh .core .snapshot import Snapshot
11+ from sqlmesh .core .snapshot import Snapshot , SnapshotInfoLike , SnapshotTableInfo
812from sqlmesh .utils .dag import DAG
913from sqlmesh .utils .date import TimeLike
1014
@@ -27,28 +31,36 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
2731 self ._complete_update_status : dict [str , bool ] = {}
2832 self ._sorted_dag = sorted_dag
2933 self ._current_index = 0
34+ self .finished_promotion = False
3035
31- def plan (self , batches : dict [Snapshot , int ]) -> None :
32- self ._batches = batches
33- self ._count : dict [Snapshot , int ] = {}
34-
35- incomplete_names = set ()
36- for snapshot , count in self ._batches .items ():
37- incomplete_names .add (snapshot .name )
38- self ._count [snapshot ] = 0
36+ def init_complete_update_status (self , snapshots : list [SnapshotTableInfo ]) -> None :
37+ planned_model_names = set ()
38+ for snapshot in snapshots :
39+ planned_model_names .add (snapshot .name )
3940
4041 # Anything not in the plan should be listed as completed and queued for
4142 # notification
4243 self ._complete_update_status = {
43- name : False for name in (set (self ._sorted_dag ) - incomplete_names )
44+ name : False for name in (set (self ._sorted_dag ) - planned_model_names )
4445 }
4546
46- def update (self , snapshot : Snapshot , _batch_idx : int ) -> tuple [int , int ]:
47+ def update_promotion (self , snapshot : SnapshotInfoLike , promoted : bool ) -> None :
48+ self ._complete_update_status [snapshot .name ] = promoted
49+
50+ def stop_promotion (self ) -> None :
51+ self .finished_promotion = True
52+
53+ def plan (self , batches : dict [Snapshot , int ]) -> None :
54+ self ._batches = batches
55+ self ._count : dict [Snapshot , int ] = {}
56+
57+ for snapshot , _ in self ._batches .items ():
58+ self ._count [snapshot ] = 0
59+
60+ def update_plan (self , snapshot : Snapshot , _batch_idx : int ) -> tuple [int , int ]:
4761 self ._count [snapshot ] += 1
4862 current_count = self ._count [snapshot ]
4963 expected_count = self ._batches [snapshot ]
50- if self ._batches [snapshot ] == self ._count [snapshot ]:
51- self ._complete_update_status [snapshot .name ] = True
5264 return (current_count , expected_count )
5365
5466 def notify_queue_next (self ) -> tuple [str , bool ] | None :
@@ -110,11 +122,12 @@ def __init__(
110122 self ._tracker = MaterializationTracker (dag .sorted [:], self ._logger )
111123 self ._stage = "plan"
112124
113- def process_events (
114- self , sqlmesh_context : SQLMeshContext , event : console .ConsoleEvent
115- ) -> t .Iterator [MaterializeResult ]:
125+ def process_events (self , event : console .ConsoleEvent ) -> None :
116126 self .report_event (event )
117127
128+ def notify_success (
129+ self , sqlmesh_context : SQLMeshContext
130+ ) -> t .Iterator [MaterializeResult ]:
118131 notify = self ._tracker .notify_queue_next ()
119132 while notify is not None :
120133 completed_name , update_status = notify
@@ -146,6 +159,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
146159
147160 match event :
148161 case console .StartPlanEvaluation (plan ):
162+ self ._tracker .init_complete_update_status (plan .environment .snapshots )
149163 log_context .info (
150164 "Starting Plan Evaluation" ,
151165 {
@@ -173,7 +187,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
173187 case console .UpdateSnapshotEvaluationProgress (
174188 snapshot , batch_idx , duration_ms
175189 ):
176- done , expected = self ._tracker .update (snapshot , batch_idx )
190+ done , expected = self ._tracker .update_plan (snapshot , batch_idx )
177191
178192 log_context .info (
179193 "Snapshot progress update" ,
@@ -200,6 +214,21 @@ def report_event(self, event: console.ConsoleEvent) -> None:
200214 [f"{ model !s} \n { model .__cause__ !s} " for model in models ]
201215 )
202216 log_context .error (f"sqlmesh failed models: { failed_models } " )
217+ case console .UpdatePromotionProgress (snapshot , promoted ):
218+ log_context .info (
219+ "Promotion progress update" ,
220+ {
221+ "snapshot" : snapshot .name ,
222+ "promoted" : promoted ,
223+ },
224+ )
225+ self ._tracker .update_promotion (snapshot , promoted )
226+ case console .StopPromotionProgress (success ):
227+ self ._tracker .stop_promotion ()
228+ if success :
229+ log_context .info ("Promotion completed successfully" )
230+ else :
231+ log_context .error ("Promotion failed" )
203232 case _:
204233 log_context .debug ("Received event" )
205234
@@ -237,6 +266,7 @@ def run(
237266 start : TimeLike | None = None ,
238267 end : TimeLike | None = None ,
239268 restate_selected : bool = False ,
269+ skip_run : bool = False ,
240270 plan_options : PlanOptions | None = None ,
241271 run_options : RunOptions | None = None ,
242272 ) -> t .Iterable [MaterializeResult ]:
@@ -287,10 +317,13 @@ def run(
287317 end = end ,
288318 select_models = select_models ,
289319 restate_selected = restate_selected ,
320+ skip_run = skip_run ,
290321 plan_options = plan_options ,
291322 run_options = run_options ,
292323 ):
293- yield from event_handler .process_events (mesh .context , event )
324+ event_handler .process_events (event )
325+
326+ yield from event_handler .notify_success (mesh .context )
294327
295328 def get_controller (
296329 self , log_override : logging .Logger | None = None
0 commit comments