2626
2727import sqlalchemy_jsonfield
2828import uuid6
29- from sqlalchemy import Column , ForeignKey , Index , Integer , String , and_ , select
29+ from sqlalchemy import Column , ForeignKey , Index , Integer , String , and_ , func , select , text
3030from sqlalchemy .exc import SQLAlchemyError
3131from sqlalchemy .orm import relationship
3232from sqlalchemy_utils import UUIDType
@@ -283,7 +283,7 @@ class BaseDeadlineReference(LoggingMixin, ABC):
283283 def reference_name (cls : Any ) -> str :
284284 return cls .__name__
285285
286- def evaluate_with (self , * , session : Session , interval : timedelta , ** kwargs : Any ) -> datetime :
286+ def evaluate_with (self , * , session : Session , interval : timedelta , ** kwargs : Any ) -> datetime | None :
287287 """Validate the provided kwargs and evaluate this deadline with the given conditions."""
288288 filtered_kwargs = {k : v for k , v in kwargs .items () if k in self .required_kwargs }
289289
@@ -295,10 +295,11 @@ def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any)
295295 if extra_kwargs := kwargs .keys () - filtered_kwargs .keys ():
296296 self .log .debug ("Ignoring unexpected parameters: %s" , ", " .join (extra_kwargs ))
297297
298- return self ._evaluate_with (session = session , ** filtered_kwargs ) + interval
298+ base_time = self ._evaluate_with (session = session , ** filtered_kwargs )
299+ return base_time + interval if base_time is not None else None
299300
    @abstractmethod
    def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
        """
        Must be implemented by subclasses to perform the actual evaluation.

        Returns the reference datetime, or None when no deadline should be
        created (the caller skips interval arithmetic in that case).
        """
        raise NotImplementedError
304305
@@ -366,6 +367,95 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime:
366367
367368 return _fetch_from_db (DagRun .queued_at , session = session , ** kwargs )
368369
@dataclass
class AverageRuntimeDeadline(BaseDeadlineReference):
    """
    A deadline that calculates the average runtime from past DAG runs.

    The deadline is ``now + avg(duration)`` over the durations of up to
    ``max_runs`` of the most recent completed runs of the DAG. If fewer than
    ``min_runs`` completed runs exist, evaluation returns None and no
    deadline is created.
    """

    # Fallback for max_runs when deserializing data that predates the field.
    DEFAULT_LIMIT = 10

    # Average over at most this many of the most recent completed runs.
    max_runs: int
    # Minimum number of completed runs required; defaults to max_runs.
    min_runs: int | None = None

    required_kwargs = {"dag_id"}

    def __post_init__(self):
        if self.min_runs is None:
            self.min_runs = self.max_runs
        # max_runs < 1 would silently produce a LIMIT 0 query and the
        # deadline would never be created; reject it up front.
        if self.max_runs < 1:
            raise ValueError("max_runs must be at least 1")
        if self.min_runs < 1:
            raise ValueError("min_runs must be at least 1")

    @provide_session
    def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
        """
        Return ``now + average runtime`` of recent completed runs of the DAG.

        Returns None (skipping deadline creation) when fewer than
        ``min_runs`` completed runs are found.
        """
        from airflow.models import DagRun

        dag_id = kwargs["dag_id"]

        # There is no portable SQL expression for a timestamp difference in
        # seconds, so build one per database dialect.
        dialect = session.bind.dialect.name
        if dialect == "postgresql":
            duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
        elif dialect == "mysql":
            # TIMESTAMPDIFF gives exact seconds, matching PostgreSQL's
            # EXTRACT(epoch FROM ...).
            duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
        elif dialect == "sqlite":
            # julianday() returns fractional days; 86400 seconds per day.
            duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
        else:
            raise ValueError(f"Unsupported database dialect: {dialect}")

        # Only completed runs (both timestamps present) count; order by
        # logical_date descending so LIMIT keeps the most recent runs.
        query = (
            select(duration_expr)
            .filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
            .order_by(DagRun.logical_date.desc())
            .limit(self.max_runs)
        )

        durations = session.execute(query).scalars().all()

        if len(durations) < cast("int", self.min_runs):
            self.log.info(
                "Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
                len(durations),
                dag_id,
                self.min_runs,
            )
            return None

        # MySQL may return Decimal values; average through Decimal to
        # preserve precision, then convert once to float.
        from decimal import Decimal

        decimal_durations = [Decimal(str(d)) for d in durations]
        avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
        self.log.info(
            "Average runtime for dag_id %s (from %d runs): %.2f seconds",
            dag_id,
            len(durations),
            avg_seconds,
        )
        return timezone.utcnow() + timedelta(seconds=avg_seconds)

    def serialize_reference(self) -> dict:
        """Serialize this reference to a JSON-safe dict for storage."""
        return {
            ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
            "max_runs": self.max_runs,
            "min_runs": self.min_runs,
        }

    @classmethod
    def deserialize_reference(cls, reference_data: dict):
        """Rebuild a reference from serialized data, applying legacy defaults."""
        max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
        # min_runs historically defaulted to max_runs when absent.
        min_runs = reference_data.get("min_runs", max_runs)
        # Validation (min_runs/max_runs >= 1) happens in __post_init__.
        return cls(
            max_runs=max_runs,
            min_runs=min_runs,
        )
369459
370460DeadlineReferenceType = ReferenceModels .BaseDeadlineReference
371461
0 commit comments