
Commit 67b71d3

Remove top-level SDK reference in Core (#59817)
1 parent 8c3b26e commit 67b71d3

161 files changed: +266 additions, -367 deletions


airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
     ConnectionHookMetaData,
     StandardHookFields,
 )
-from airflow.sdk import Param
+from airflow.serialization.definitions.param import SerializedParam

 if TYPE_CHECKING:
     from airflow.providers_manager import ConnectionFormWidgetInfo, HookInfo
@@ -79,7 +79,7 @@ def __init__(
         for v in validators:
             if isinstance(v, HookMetaService.MockEnum):
                 enum = {"enum": v.allowed_values}
-        self.param = Param(
+        self.param = SerializedParam(
             default=default,
             title=label,
             description=description or None,

airflow-core/src/airflow/cli/commands/dag_command.py

Lines changed: 2 additions & 1 deletion

@@ -60,6 +60,7 @@
     from sqlalchemy.orm import Session

     from airflow import DAG
+    from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.timetables.base import DataInterval

 DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fields}
@@ -656,7 +657,7 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
             )
         ).all()

-        dot_graph = render_dag(dag, tis=list(tis))
+        dot_graph = render_dag(cast("SerializedDAG", dag), tis=list(tis))
         print()
         if filename:
             _save_dot_to_file(dot_graph, filename)
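The cast("SerializedDAG", dag) call above changes nothing at runtime; cast only tells the static type checker to treat dag as the serialized type that render_dag now expects. A minimal sketch of that behavior, with illustrative names that are not Airflow APIs:

from typing import cast


def shout(value: object) -> str:
    # cast() performs no conversion or validation at runtime; it only tells
    # the static type checker to treat `value` as a str from here on.
    text = cast(str, value)
    return text.upper()


print(shout("hello"))  # prints HELLO; the cast was a runtime no-op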

airflow-core/src/airflow/models/connection.py

Lines changed: 2 additions & 1 deletion

@@ -36,7 +36,6 @@
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.models.base import ID_LEN, Base
 from airflow.models.crypto import get_fernet
-from airflow.sdk import SecretCache
 from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -531,6 +530,8 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:

         # check cache first
         # enabled only if SecretCache.init() has been called first
+        from airflow.sdk import SecretCache
+
         try:
             uri = SecretCache.get_connection_uri(conn_id)
             return Connection(conn_id=conn_id, uri=uri)
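This hunk shows the commit's recurring move: a module-level SDK import becomes a function-local one, so importing airflow.models.connection no longer pulls in airflow.sdk (and any circular-import risk it carries) at module load time. A minimal sketch of the pattern, using a stdlib module as a stand-in for the heavy dependency:

def render_report(data: dict) -> str:
    # Deferred import: json is loaded only when this function first runs,
    # so merely importing the enclosing module stays cheap.
    import json

    # Repeat calls stay fast because Python caches modules in sys.modules.
    return json.dumps(data, indent=2)


print(render_report({"conn_id": "my_db"}))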

airflow-core/src/airflow/models/dag.py

Lines changed: 13 additions & 13 deletions

@@ -21,11 +21,10 @@
 from collections import defaultdict
 from collections.abc import Callable, Collection
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, cast

 import pendulum
 import sqlalchemy_jsonfield
-from dateutil.relativedelta import relativedelta
 from sqlalchemy import (
     Boolean,
     Float,
@@ -61,7 +60,6 @@
 from airflow.timetables.base import DataInterval, Timetable
 from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
 from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable
-from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
 from airflow.utils.state import DagRunState
@@ -70,6 +68,9 @@
 if TYPE_CHECKING:
     from typing import TypeAlias

+    from dateutil.relativedelta import relativedelta
+
+    from airflow.sdk import Context
     from airflow.serialization.definitions.assets import (
         SerializedAsset,
         SerializedAssetAlias,
@@ -78,21 +79,20 @@
     from airflow.serialization.definitions.dag import SerializedDAG

     UKey: TypeAlias = SerializedAssetUniqueKey
+    DagStateChangeCallback = Callable[[Context], None]
+    ScheduleInterval = None | str | timedelta | relativedelta
+
+    ScheduleArg = (
+        ScheduleInterval
+        | Timetable
+        | "SerializedAssetBase"
+        | Collection["SerializedAsset" | "SerializedAssetAlias"]
+    )

 log = logging.getLogger(__name__)

 TAG_MAX_LEN = 100

-DagStateChangeCallback = Callable[[Context], None]
-ScheduleInterval = None | str | timedelta | relativedelta
-
-ScheduleArg = Union[
-    ScheduleInterval,
-    Timetable,
-    "SerializedAssetBase",
-    Collection[Union["SerializedAsset", "SerializedAssetAlias"]],
-]
-

 def infer_automated_data_interval(timetable: Timetable, logical_date: datetime) -> DataInterval:
     """

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 1 addition & 1 deletion

@@ -114,10 +114,10 @@
     from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
     from airflow.models.dag import DagModel
     from airflow.models.dagrun import DagRun
+    from airflow.sdk import Context
     from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.serialization.definitions.mappedoperator import Operator
     from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-    from airflow.utils.context import Context


 PAST_DEPENDS_MET = "past_depends_met"

airflow-core/src/airflow/models/variable.py

Lines changed: 6 additions & 1 deletion

@@ -32,7 +32,6 @@
 from airflow.configuration import conf, ensure_secrets_loaded
 from airflow.models.base import ID_LEN, Base
 from airflow.models.crypto import get_fernet
-from airflow.sdk import SecretCache
 from airflow.secrets.metastore import MetastoreBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -238,6 +237,8 @@ def set(
         )

         # check if the secret exists in the custom secrets' backend.
+        from airflow.sdk import SecretCache
+
         Variable.check_for_write_conflict(key=key)
         if serialize_json:
             stored_value = json.dumps(value, indent=2)
@@ -428,6 +429,8 @@ def delete(key: str, team_name: str | None = None, session: Session | None = None
                 "Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
             )

+        from airflow.sdk import SecretCache
+
         ctx: contextlib.AbstractContextManager
         if session is not None:
            ctx = contextlib.nullcontext(session)
@@ -494,6 +497,8 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None:
        :param team_name: Team name associated to the task trying to access the variable (if any)
        :return: Variable Value
        """
+        from airflow.sdk import SecretCache
+
        # Disable cache if the variable belongs to a team. We might enable it later
        if not team_name:
            # check cache first

airflow-core/src/airflow/serialization/definitions/mappedoperator.py

Lines changed: 0 additions & 5 deletions

@@ -30,7 +30,6 @@

 from airflow.exceptions import AirflowException, NotMapped
 from airflow.sdk import BaseOperator as TaskSDKBaseOperator
-from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER
 from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
 from airflow.serialization.definitions.baseoperator import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator
 from airflow.serialization.definitions.node import DAGNode
@@ -288,10 +287,6 @@ def retry_exponential_backoff(self) -> float:
     def max_retry_delay(self) -> datetime.timedelta | float | None:
         return self._get_partial_kwargs_or_operator_default("max_retry_delay")

-    @property
-    def retry_delay_multiplier(self) -> float:
-        return float(self.partial_kwargs.get("retry_delay_multiplier", DEFAULT_RETRY_DELAY_MULTIPLIER))
-
     @property
     def weight_rule(self) -> PriorityWeightStrategy:
         from airflow.serialization.definitions.baseoperator import SerializedBaseOperator

airflow-core/src/airflow/utils/cli.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,6 @@
 from airflow._shared.timezones import timezone
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.exceptions import AirflowException
-from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils import cli_action_loggers
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
 from airflow.utils.platform import getuser, is_terminal_support_colors
@@ -274,6 +273,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None
     dags folder.
     """
     from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
+    from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager

     manager = DagBundlesManager()
     for bundle_name in bundle_names or ():

airflow-core/src/airflow/utils/context.py

Lines changed: 15 additions & 29 deletions

@@ -19,23 +19,22 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, cast
+import warnings
+from typing import Any

 from sqlalchemy import select

 from airflow.models.asset import AssetModel
-from airflow.sdk import Asset, Context
+from airflow.sdk import Asset
 from airflow.sdk.execution_time.context import (
     ConnectionAccessor as ConnectionAccessorSDK,
     OutletEventAccessors as OutletEventAccessorsSDK,
     VariableAccessor as VariableAccessorSDK,
 )
 from airflow.serialization.definitions.notset import NOTSET, is_arg_set
+from airflow.utils.deprecation_tools import DeprecatedImportWarning
 from airflow.utils.session import create_session

-if TYPE_CHECKING:
-    from collections.abc import Container
-
 # NOTE: Please keep this in sync with the following:
 # * Context in task-sdk/src/airflow/sdk/definitions/context.py
 # * Table in docs/apache-airflow/templates-ref.rst
@@ -141,30 +140,17 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
     return Asset(name=asset.name, uri=asset.uri, group=asset.group, extra=asset.extra)


-def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
-    """
-    Merge parameters into an existing context.
-
-    Like ``dict.update()`` , this take the same parameters, and updates
-    ``context`` in-place.
-
-    This is implemented as a free function because the ``Context`` type is
-    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
-    functions.
+def __getattr__(name: str):
+    if name in ("Context", "context_copy_partial", "context_merge"):
+        warnings.warn(
+            "Importing Context from airflow.utils.context is deprecated and will "
+            "be removed in the future. Please import it from airflow.sdk instead.",
+            DeprecatedImportWarning,
+            stacklevel=2,
+        )

-    :meta private:
-    """
-    if not context:
-        context = Context()
+        import airflow.sdk.definitions.context as sdk

-    context.update(*args, **kwargs)
+        return getattr(sdk, name)

-
-def context_copy_partial(source: Context, keys: Container[str]) -> Context:
-    """
-    Create a context by copying items under selected keys in ``source``.
-
-    :meta private:
-    """
-    new = {k: v for k, v in source.items() if k in keys}
-    return cast("Context", new)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
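The replacement body relies on PEP 562: a module-level __getattr__ runs only when normal attribute lookup on the module fails, which makes it a natural hook for deprecation forwarding. A runnable standalone sketch; here Path is forwarded to pathlib so the example works as-is, whereas the real code forwards to airflow.sdk.definitions.context:

# legacy_module.py - a minimal PEP 562 deprecation shim.
import warnings


def __getattr__(name: str):
    # Called only when `name` is not found among this module's globals.
    if name == "Path":
        warnings.warn(
            "Importing Path from legacy_module is deprecated; "
            "import it from pathlib instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        import pathlib

        return getattr(pathlib, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this in place, "from legacy_module import Path" still succeeds but warns at the caller's import site, while any other missing name raises AttributeError as usual.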

airflow-core/src/airflow/utils/dag_edges.py

Lines changed: 15 additions & 13 deletions

@@ -16,20 +16,23 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any

-from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
-from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
-from airflow.serialization.definitions.dag import SerializedDAG
-from airflow.serialization.definitions.mappedoperator import SerializedMappedOperator
+from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
+
+# Also support SDK types if possible.
+try:
+    from airflow.sdk import TaskGroup
+except ImportError:
+    TaskGroup = SerializedTaskGroup  # type: ignore[misc]

 if TYPE_CHECKING:
-    from airflow.sdk import DAG
     from airflow.serialization.definitions.dag import SerializedDAG
     from airflow.serialization.definitions.mappedoperator import Operator
+    from airflow.serialization.definitions.node import DAGNode


-def dag_edges(dag: DAG | SerializedDAG):
+def dag_edges(dag: SerializedDAG):
     """
     Create the list of edges needed to construct the Graph view.

@@ -62,9 +65,10 @@ def dag_edges(dag: DAG | SerializedDAG):

     task_group_map = dag.task_group.get_task_group_dict()

-    def collect_edges(task_group):
+    def collect_edges(task_group: DAGNode) -> None:
         """Update edges_to_add and edges_to_skip according to TaskGroups."""
-        if isinstance(task_group, (AbstractOperator, SerializedBaseOperator, SerializedMappedOperator)):
+        child: DAGNode
+        if not isinstance(task_group, (TaskGroup, SerializedTaskGroup)):
             return

         for target_id in task_group.downstream_group_ids:
@@ -111,9 +115,7 @@ def collect_edges(task_group):
     edges = set()
     setup_teardown_edges = set()

-    # TODO (GH-52141): 'roots' in scheduler needs to return scheduler types
-    # instead, but currently it inherits SDK's DAG.
-    tasks_to_trace = cast("list[Operator]", dag.roots)
+    tasks_to_trace = dag.roots
     while tasks_to_trace:
         tasks_to_trace_next: list[Operator] = []
         for task in tasks_to_trace:
@@ -130,7 +132,7 @@ def collect_edges(task_group):
     # Build result dicts with the two ends of the edge, plus any extra metadata
     # if we have it.
     for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip):
-        record = {"source_id": source_id, "target_id": target_id}
+        record: dict[str, Any] = {"source_id": source_id, "target_id": target_id}
         label = dag.get_edge_info(source_id, target_id).get("label")
         if (source_id, target_id) in setup_teardown_edges:
             record["is_setup_teardown"] = True
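The try/except ImportError block at the top of this file is a standard optional-dependency fallback: when the Task SDK is installed its TaskGroup joins the isinstance check, and when it is absent the serialized class stands in for both names. A runnable sketch of the same pattern; sdk_pkg is a made-up package, so the fallback branch is what actually executes here:

class SerializedWidget:
    """Always-available local implementation."""


try:
    from sdk_pkg import Widget  # hypothetical optional dependency
except ImportError:
    Widget = SerializedWidget  # type: ignore[misc]


def is_widget(obj: object) -> bool:
    # isinstance accepts a tuple, so either class matches; with the
    # fallback bound, both names refer to SerializedWidget.
    return isinstance(obj, (Widget, SerializedWidget))


print(is_widget(SerializedWidget()))  # True whether or not sdk_pkg exists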
