Skip to content

Commit 2090656

Browse files
authored
Exploring possibility of moving DagAttributeTypes to execution API spec (#61251)
1 parent 027c7cf commit 2090656

File tree

6 files changed

+17
-6
lines changed

6 files changed

+17
-6
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/app.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def get_extra_schemas() -> dict[str, dict]:
265265
"""Get all the extra schemas that are not part of the main FastAPI app."""
266266
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
267267
from airflow.executors.workloads import BundleInfo
268+
from airflow.serialization.enums import DagAttributeTypes
268269
from airflow.task.trigger_rule import TriggerRule
269270
from airflow.task.weight_rule import WeightRule
270271
from airflow.utils.state import TaskInstanceState, TerminalTIState
@@ -278,6 +279,11 @@ def get_extra_schemas() -> dict[str, dict]:
278279
"TaskInstanceState": {"type": "string", "enum": list(TaskInstanceState)},
279280
"WeightRule": {"type": "string", "enum": list(WeightRule)},
280281
"TriggerRule": {"type": "string", "enum": list(TriggerRule)},
282+
"DagAttributeTypes": {
283+
"type": "string",
284+
"enum": [DagAttributeTypes.OP.value, DagAttributeTypes.TASK_GROUP.value],
285+
"x-enum-varnames": [DagAttributeTypes.OP.name, DagAttributeTypes.TASK_GROUP.name],
286+
},
281287
}
282288

283289

task-sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,11 @@ class TriggerRule(str, Enum):
485485
ALL_SKIPPED = "all_skipped"
486486

487487

488+
class DagAttributeTypes(str, Enum):
489+
OP = "operator"
490+
TASK_GROUP = "taskgroup"
491+
492+
488493
class AssetReferenceAssetEventDagRun(BaseModel):
489494
"""
490495
Schema for AssetModel used in AssetEventDagRunReference.

task-sdk/src/airflow/sdk/bases/operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ def db_safe_priority(priority_weight: int) -> int:
8989
import jinja2
9090
from typing_extensions import Self
9191

92+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
9293
from airflow.sdk.bases.operatorlink import BaseOperatorLink
9394
from airflow.sdk.definitions.context import Context
9495
from airflow.sdk.definitions.dag import DAG
9596
from airflow.sdk.definitions.operator_resources import Resources
9697
from airflow.sdk.definitions.taskgroup import TaskGroup
9798
from airflow.sdk.definitions.xcom_arg import XComArg
98-
from airflow.serialization.enums import DagAttributeTypes
9999
from airflow.task.priority_strategy import PriorityWeightStrategy
100100
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
101101

@@ -1554,7 +1554,7 @@ def prepare_for_execution(self) -> Self:
15541554

15551555
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
15561556
"""Serialize; required by DAGNode."""
1557-
from airflow.serialization.enums import DagAttributeTypes
1557+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
15581558

15591559
return DagAttributeTypes.OP, self.task_id
15601560

task-sdk/src/airflow/sdk/definitions/_internal/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
from airflow.sdk.definitions._internal.mixins import DependencyMixin
2828

2929
if TYPE_CHECKING:
30+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
3031
from airflow.sdk.definitions.dag import DAG
3132
from airflow.sdk.definitions.edges import EdgeModifier
3233
from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401
3334
from airflow.sdk.types import Operator # noqa: F401
34-
from airflow.serialization.enums import DagAttributeTypes
3535

3636

3737
KEY_REGEX = re.compile(r"^[\w.-]+$")

task-sdk/src/airflow/sdk/definitions/mappedoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import methodtools
2828
from lazy_object_proxy import Proxy
2929

30+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
3031
from airflow.sdk.bases.xcom import BaseXCom
3132
from airflow.sdk.definitions._internal.abstractoperator import (
3233
DEFAULT_EXECUTOR,
@@ -50,7 +51,6 @@
5051
is_mappable,
5152
)
5253
from airflow.sdk.definitions._internal.types import NOTSET
53-
from airflow.serialization.enums import DagAttributeTypes
5454

5555
if TYPE_CHECKING:
5656
import datetime

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
)
3737

3838
if TYPE_CHECKING:
39+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
3940
from airflow.sdk.bases.operator import BaseOperator
4041
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
4142
from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput
4243
from airflow.sdk.definitions._internal.mixins import DependencyMixin
4344
from airflow.sdk.definitions.dag import DAG
4445
from airflow.sdk.definitions.edges import EdgeModifier
4546
from airflow.sdk.types import Operator
46-
from airflow.serialization.enums import DagAttributeTypes
4747

4848

4949
def _default_parent_group() -> TaskGroup | None:
@@ -493,7 +493,7 @@ def get_child_by_label(self, label: str) -> DAGNode:
493493

494494
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
495495
"""Serialize task group; required by DagNode."""
496-
from airflow.serialization.enums import DagAttributeTypes
496+
from airflow.sdk.api.datamodels._generated import DagAttributeTypes
497497
from airflow.serialization.serialized_objects import TaskGroupSerialization
498498

499499
return (

0 commit comments

Comments
 (0)