11import logging
22import typing as t
3+ from types import MappingProxyType
34
45from dagster import (
56 AssetExecutionContext ,
67 ConfigurableResource ,
78 MaterializeResult ,
89)
10+ from dagster ._core .errors import DagsterInvalidPropertyError
911from sqlmesh import Model
1012from sqlmesh .core .context import Context as SQLMeshContext
1113from sqlmesh .core .snapshot import Snapshot , SnapshotInfoLike , SnapshotTableInfo
1214from sqlmesh .utils .dag import DAG
1315from sqlmesh .utils .date import TimeLike
16+ from sqlmesh .utils .errors import SQLMeshError
1417
1518from dagster_sqlmesh .controller .base import (
1619 DEFAULT_CONTEXT_FACTORY ,
@@ -113,20 +116,41 @@ def event_name(self):
113116 return self ._event .__class__ .__name__
114117
115118
119+ class GenericSQLMeshError (Exception ):
120+ pass
121+
122+
123+ class FailedModelError (Exception ):
124+ def __init__ (self , model_name : str , message : str | None ) -> None :
125+ super ().__init__ (message )
126+ self .model_name = model_name
127+ self .message = message
128+
129+
130+ class PlanOrRunFailedError (Exception ):
131+ def __init__ (self , stage : str , message : str , errors : list [Exception ]) -> None :
132+ super ().__init__ (message )
133+ self .stage = stage
134+ self .errors = errors
135+
136+
116137class DagsterSQLMeshEventHandler :
117138 def __init__ (
118139 self ,
119140 context : AssetExecutionContext ,
120141 models_map : dict [str , Model ],
121142 dag : DAG [t .Any ],
122143 prefix : str ,
144+ is_testing : bool = False ,
123145 ) -> None :
124146 self ._models_map = models_map
125147 self ._prefix = prefix
126148 self ._context = context
127149 self ._logger = context .log
128150 self ._tracker = MaterializationTracker (dag .sorted [:], self ._logger )
129151 self ._stage = "plan"
152+ self ._errors : list [Exception ] = []
153+ self ._is_testing = is_testing
130154
131155 def process_events (self , event : console .ConsoleEvent ) -> None :
132156 self .report_event (event )
@@ -150,14 +174,17 @@ def notify_success(
150174 # If the model is not in models_map, we can skip any notification
151175 if model :
152176 output_key = sqlmesh_model_name_to_key (model .name )
153- asset_key = self ._context .asset_key_for_output (output_key )
154- yield MaterializeResult (
155- asset_key = asset_key ,
156- metadata = {
157- "updated" : update_status ,
158- "duration_ms" : 0 ,
159- },
160- )
177+ if not self ._is_testing :
178+ # Stupidly dagster when testing cannot use the following
179+ # method so we must specifically skip this when testing
180+ asset_key = self ._context .asset_key_for_output (output_key )
181+ yield MaterializeResult (
182+ asset_key = asset_key ,
183+ metadata = {
184+ "updated" : update_status ,
185+ "duration_ms" : 0 ,
186+ },
187+ )
161188 notify = self ._tracker .notify_queue_next ()
162189
163190 def report_event (self , event : console .ConsoleEvent ) -> None :
@@ -210,19 +237,22 @@ def report_event(self, event: console.ConsoleEvent) -> None:
210237 if success :
211238 log_context .info ("sqlmesh ran successfully" )
212239 else :
213- log_context .error ("sqlmesh failed" )
214- raise Exception ("sqlmesh failed during run" )
240+ log_context .error ("sqlmesh failed. check collected errors" )
215241 case console .LogError (message = message ):
216242 log_context .error (
217243 f"sqlmesh reported an error: { message } " ,
218244 )
219- case console .LogFailedModels (models = models ):
220- if len (models ) != 0 :
245+ self ._errors .append (GenericSQLMeshError (message ))
246+ case console .LogFailedModels (errors = errors ):
247+ if len (errors ) != 0 :
221248 failed_models = "\n " .join (
222- [f"{ model !s} \n { model .__cause__ !s} " for model in models ]
249+ [f"{ error . node !s} \n { error .__cause__ !s} " for error in errors ]
223250 )
224251 log_context .error (f"sqlmesh failed models: { failed_models } " )
225- raise Exception ("sqlmesh has failed models" )
252+ for error in errors :
253+ self ._errors .append (
254+ FailedModelError (error .node , str (error .__cause__ ))
255+ )
226256 case console .UpdatePromotionProgress (snapshot = snapshot , promoted = promoted ):
227257 log_context .info (
228258 "Promotion progress update" ,
@@ -263,9 +293,18 @@ def log(
263293 def update_stage (self , stage : str ):
264294 self ._stage = stage
265295
296+ @property
297+ def stage (self ) -> str :
298+ return self ._stage
299+
300+ @property
301+ def errors (self ) -> list [Exception ]:
302+ return self ._errors [:]
303+
266304
267305class SQLMeshResource (ConfigurableResource ):
268306 config : SQLMeshContextConfig
307+ is_testing : bool = False
269308
270309 def run (
271310 self ,
@@ -293,25 +332,16 @@ def run(
293332 with controller .instance (environment ) as mesh :
294333 dag = mesh .models_dag ()
295334
296- select_models = []
297-
298335 models = mesh .models ()
299336 models_map = models .copy ()
300337 all_available_models = set (
301338 [model .fqn for model , _ in mesh .non_external_models_dag ()]
302339 )
303- if context .selected_output_names :
304- models_map = {}
305- for key , model in models .items ():
306- if (
307- sqlmesh_model_name_to_key (model .name )
308- in context .selected_output_names
309- ):
310- models_map [key ] = model
311- select_models .append (model .name )
312- selected_models_set = set (models_map .keys ())
313-
314- if all_available_models == selected_models_set :
340+ selected_models_set , models_map , select_models = (
341+ self ._get_selected_models_from_context (context , models )
342+ )
343+
344+ if all_available_models == selected_models_set or select_models is None :
315345 logger .info ("all models selected" )
316346
317347 # Setting this to none to allow sqlmesh to select all models and
@@ -321,24 +351,61 @@ def run(
321351 logger .info (f"selected models: { select_models } " )
322352
323353 event_handler = DagsterSQLMeshEventHandler (
324- context , models_map , dag , "sqlmesh: "
354+ context , models_map , dag , "sqlmesh: " , is_testing = self . is_testing
325355 )
326356
327- for event in mesh .plan_and_run (
328- start = start ,
329- end = end ,
330- select_models = select_models ,
331- restate_models = restate_models ,
332- restate_selected = restate_selected ,
333- skip_run = skip_run ,
334- plan_options = plan_options ,
335- run_options = run_options ,
336- ):
337- logger .debug (f"sqlmesh event: { event } " )
338- event_handler .process_events (event )
339-
357+ try :
358+ for event in mesh .plan_and_run (
359+ start = start ,
360+ end = end ,
361+ select_models = select_models ,
362+ restate_models = restate_models ,
363+ restate_selected = restate_selected ,
364+ skip_run = skip_run ,
365+ plan_options = plan_options ,
366+ run_options = run_options ,
367+ ):
368+ logger .debug (f"sqlmesh event: { event } " )
369+ event_handler .process_events (event )
370+ except SQLMeshError as e :
371+ logger .error (f"sqlmesh error: { e } " )
372+ errors = event_handler .errors
373+ for error in errors :
374+ logger .error (f"sqlmesh encountered the following error during sqlmesh { event_handler .stage } : { error } " )
375+ raise PlanOrRunFailedError (
376+ event_handler .stage ,
377+ f"sqlmesh failed during { event_handler .stage } with { len (event_handler .errors ) + 1 } errors" ,
378+ [e , * event_handler .errors ],
379+ )
340380 yield from event_handler .notify_success (mesh .context )
341381
382+ def _get_selected_models_from_context (
383+ self , context : AssetExecutionContext , models : MappingProxyType [str , Model ]
384+ ) -> tuple [set [str ], dict [str , Model ], list [str ] | None ]:
385+ models_map = models .copy ()
386+ try :
387+ selected_output_names = set (context .selected_output_names )
388+ except (DagsterInvalidPropertyError , AttributeError ) as e :
389+ # Special case for direct execution context when testing. This is related to:
390+ # https://github.com/dagster-io/dagster/issues/23633
391+ if "DirectOpExecutionContext" in str (e ):
392+ context .log .warning ("Caught an error that is likely a direct execution" )
393+ return (set (models_map .keys ()), models_map , None )
394+ else :
395+ raise e
396+
397+ select_models : list [str ] = []
398+ models_map = {}
399+ for key , model in models .items ():
400+ if sqlmesh_model_name_to_key (model .name ) in selected_output_names :
401+ models_map [key ] = model
402+ select_models .append (model .name )
403+ return (
404+ set (models_map .keys ()),
405+ models_map ,
406+ select_models ,
407+ )
408+
342409 def get_controller (
343410 self ,
344411 context_factory : ContextFactory [ContextCls ],
0 commit comments