Skip to content

Commit 7b562ba

Browse files
Use average runtime as deadline reference (apache#55088)
Co-authored-by: Ramit Kataria <ramitkat@amazon.com>
1 parent b3dad09 commit 7b562ba

File tree

6 files changed

+315
-17
lines changed

6 files changed

+315
-17
lines changed

airflow-core/docs/howto/deadline-alerts.rst

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,58 @@ Airflow provides several built-in reference points that you can use with Deadlin
104104
``DeadlineReference.FIXED_DATETIME``
105105
Specifies a fixed point in time. Useful when Dags must complete by a specific time.
106106

107+
``DeadlineReference.AVERAGE_RUNTIME``
108+
Calculates deadlines based on the average runtime of previous DAG runs. This reference
109+
analyzes historical execution data to predict when the current run should complete.
110+
The deadline is set to the current time plus the calculated average runtime plus the interval.
111+
If fewer than ``min_runs`` completed runs exist, no deadline is created.
112+
113+
Parameters:
114+
* ``max_runs`` (int, optional): Maximum number of recent DAG runs to analyze. Defaults to 10.
115+
* ``min_runs`` (int, optional): Minimum number of completed runs required to calculate the average. Defaults to the same value as ``max_runs``.
116+
117+
Example usage:
118+
119+
.. code-block:: python
120+
121+
# Use default settings (analyze up to 10 runs, require 10 runs)
122+
DeadlineReference.AVERAGE_RUNTIME()
123+
124+
# Analyze up to 20 runs but calculate with minimum 5 runs
125+
DeadlineReference.AVERAGE_RUNTIME(max_runs=20, min_runs=5)
126+
127+
# Strict: require exactly 15 runs to calculate
128+
DeadlineReference.AVERAGE_RUNTIME(max_runs=15, min_runs=15)
129+
130+
Here's an example using average runtime:
131+
132+
.. code-block:: python
133+
134+
with DAG(
135+
dag_id="average_runtime_deadline",
136+
deadline=DeadlineAlert(
137+
reference=DeadlineReference.AVERAGE_RUNTIME(max_runs=15, min_runs=5),
138+
interval=timedelta(minutes=30), # Alert if 30 minutes past average runtime
139+
callback=AsyncCallback(
140+
SlackWebhookNotifier,
141+
kwargs={"text": "🚨 DAG {{ dag_run.dag_id }} is running longer than expected!"},
142+
),
143+
),
144+
):
145+
EmptyOperator(task_id="data_processing")
146+
147+
If the calculated historical average is 30 minutes, the timeline for this example looks like this:
148+
149+
::
150+
151+
|------|----------|--------------|--------------|--------|
152+
Queued Start | Deadline
153+
09:00 09:05 09:35 10:05
154+
| | |
155+
|--- Average --|-- Interval --|
156+
(30 min) (30 min)
157+
158+
107159
Here's an example using a fixed datetime:
108160

109161
.. code-block:: python
@@ -166,6 +218,7 @@ Here's an example using the Slack Notifier if the Dag run has not finished withi
166218
):
167219
EmptyOperator(task_id="example_task")
168220
221+
169222
Creating Custom Callbacks
170223
^^^^^^^^^^^^^^^^^^^^^^^^^
171224

airflow-core/src/airflow/models/deadline.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import sqlalchemy_jsonfield
2828
import uuid6
29-
from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, select
29+
from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, func, select, text
3030
from sqlalchemy.exc import SQLAlchemyError
3131
from sqlalchemy.orm import relationship
3232
from sqlalchemy_utils import UUIDType
@@ -283,7 +283,7 @@ class BaseDeadlineReference(LoggingMixin, ABC):
283283
def reference_name(cls: Any) -> str:
284284
return cls.__name__
285285

286-
def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime:
286+
def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
287287
"""Validate the provided kwargs and evaluate this deadline with the given conditions."""
288288
filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}
289289

@@ -295,10 +295,11 @@ def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any)
295295
if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
296296
self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))
297297

298-
return self._evaluate_with(session=session, **filtered_kwargs) + interval
298+
base_time = self._evaluate_with(session=session, **filtered_kwargs)
299+
return base_time + interval if base_time is not None else None
299300

300301
@abstractmethod
301-
def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime:
302+
def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
302303
"""Must be implemented by subclasses to perform the actual evaluation."""
303304
raise NotImplementedError
304305

@@ -366,6 +367,95 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime:
366367

367368
return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)
368369

370+
@dataclass
371+
class AverageRuntimeDeadline(BaseDeadlineReference):
372+
"""A deadline that calculates the average runtime from past DAG runs."""
373+
374+
DEFAULT_LIMIT = 10
375+
max_runs: int
376+
min_runs: int | None = None
377+
required_kwargs = {"dag_id"}
378+
379+
def __post_init__(self):
380+
if self.min_runs is None:
381+
self.min_runs = self.max_runs
382+
if self.min_runs < 1:
383+
raise ValueError("min_runs must be at least 1")
384+
385+
@provide_session
386+
def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
387+
from airflow.models import DagRun
388+
389+
dag_id = kwargs["dag_id"]
390+
391+
# Get database dialect to use appropriate time difference calculation
392+
dialect = session.bind.dialect.name
393+
394+
# Create database-specific expression for calculating duration in seconds
395+
if dialect == "postgresql":
396+
duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
397+
elif dialect == "mysql":
398+
# Use TIMESTAMPDIFF to get exact seconds like PostgreSQL EXTRACT(epoch FROM ...)
399+
duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
400+
elif dialect == "sqlite":
401+
duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
402+
else:
403+
raise ValueError(f"Unsupported database dialect: {dialect}")
404+
405+
# Query for completed DAG runs with both start and end dates
406+
# Order by logical_date descending to get most recent runs first
407+
query = (
408+
select(duration_expr)
409+
.filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
410+
.order_by(DagRun.logical_date.desc())
411+
)
412+
413+
# Apply max_runs
414+
query = query.limit(self.max_runs)
415+
416+
# Get all durations and calculate average
417+
durations = session.execute(query).scalars().all()
418+
419+
if len(durations) < cast("int", self.min_runs):
420+
logger.info(
421+
"Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
422+
len(durations),
423+
dag_id,
424+
self.min_runs,
425+
)
426+
return None
427+
# Convert to float to handle Decimal types from MySQL while preserving precision
428+
# Use Decimal arithmetic for higher precision, then convert to float
429+
from decimal import Decimal
430+
431+
decimal_durations = [Decimal(str(d)) for d in durations]
432+
avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
433+
logger.info(
434+
"Average runtime for dag_id %s (from %d runs): %.2f seconds",
435+
dag_id,
436+
len(durations),
437+
avg_seconds,
438+
)
439+
return timezone.utcnow() + timedelta(seconds=avg_seconds)
440+
441+
def serialize_reference(self) -> dict:
442+
return {
443+
ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
444+
"max_runs": self.max_runs,
445+
"min_runs": self.min_runs,
446+
}
447+
448+
@classmethod
449+
def deserialize_reference(cls, reference_data: dict):
450+
max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
451+
min_runs = reference_data.get("min_runs", max_runs)
452+
if min_runs < 1:
453+
raise ValueError("min_runs must be at least 1")
454+
return cls(
455+
max_runs=max_runs,
456+
min_runs=min_runs,
457+
)
458+
369459

370460
DeadlineReferenceType = ReferenceModels.BaseDeadlineReference
371461

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3270,18 +3270,20 @@ def create_dagrun(
32703270
if self.deadline:
32713271
for deadline in cast("list", self.deadline):
32723272
if isinstance(deadline.reference, DeadlineReference.TYPES.DAGRUN):
3273-
session.add(
3274-
Deadline(
3275-
deadline_time=deadline.reference.evaluate_with(
3276-
session=session,
3277-
interval=deadline.interval,
3278-
dag_id=self.dag_id,
3279-
run_id=run_id,
3280-
),
3281-
callback=deadline.callback,
3282-
dagrun_id=orm_dagrun.id,
3283-
)
3273+
deadline_time = deadline.reference.evaluate_with(
3274+
session=session,
3275+
interval=deadline.interval,
3276+
dag_id=self.dag_id,
3277+
run_id=run_id,
32843278
)
3279+
if deadline_time is not None:
3280+
session.add(
3281+
Deadline(
3282+
deadline_time=deadline_time,
3283+
callback=deadline.callback,
3284+
dagrun_id=orm_dagrun.id,
3285+
)
3286+
)
32853287

32863288
return orm_dagrun
32873289

0 commit comments

Comments
 (0)