Skip to content

Commit e9c2a6f

Browse files
authored
Revert "Remove PriorityWeightStrategy reference in SDK" (#59828)
* Revert "Remove PriorityWeightStrategy reference in SDK (#59780)" This reverts commit 60b4ed4. * Tip by TP
1 parent 87ba445 commit e9c2a6f

File tree

13 files changed

+136
-168
lines changed

13 files changed

+136
-168
lines changed

airflow-core/newsfragments/59780.significant.rst

Lines changed: 0 additions & 4 deletions
This file was deleted.

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@
9696
PriorityWeightStrategy,
9797
airflow_priority_weight_strategies,
9898
airflow_priority_weight_strategies_classes,
99-
validate_and_load_priority_weight_strategy,
10099
)
101100
from airflow.timetables.base import DagRunInfo, Timetable
102101
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -250,15 +249,15 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
250249
return partition_mapper_class.deserialize(var[Encoding.VAR])
251250

252251

253-
def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
252+
def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
254253
"""
255254
Encode a priority weight strategy instance.
256255
257256
In this version, we only store the importable string, so the class should not wait
258257
for any parameters to be passed to it. If you need to store the parameters, you
259258
should store them in the class itself.
260259
"""
261-
priority_weight_strategy_class = type(validate_and_load_priority_weight_strategy(var))
260+
priority_weight_strategy_class = type(var)
262261
if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
263262
return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
264263
importable_string = qualname(priority_weight_strategy_class)

airflow-core/src/airflow/task/priority_strategy.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
from __future__ import annotations
2121

2222
from abc import ABC, abstractmethod
23-
from typing import TYPE_CHECKING
23+
from typing import TYPE_CHECKING, Any
2424

25-
from airflow._shared.module_loading import qualname
2625
from airflow.task.weight_rule import WeightRule
2726

2827
if TYPE_CHECKING:
@@ -42,22 +41,46 @@ class PriorityWeightStrategy(ABC):
4241
"""
4342

4443
@abstractmethod
45-
def get_weight(self, ti: TaskInstance) -> int:
44+
def get_weight(self, ti: TaskInstance):
4645
"""Get the priority weight of a task."""
47-
raise NotImplementedError("must be implemented by a subclass")
46+
...
47+
48+
@classmethod
49+
def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
50+
"""
51+
Deserialize a priority weight strategy from data.
52+
53+
This is called when a serialized DAG is deserialized. ``data`` will be whatever
54+
was returned by ``serialize`` during DAG serialization. The default
55+
implementation constructs the priority weight strategy without any arguments.
56+
"""
57+
return cls(**data)
58+
59+
def serialize(self) -> dict[str, Any]:
60+
"""
61+
Serialize the priority weight strategy for JSON encoding.
62+
63+
This is called during DAG serialization to store priority weight strategy information
64+
in the database. This should return a JSON-serializable dict that will be fed into
65+
``deserialize`` when the DAG is deserialized. The default implementation returns
66+
an empty dict.
67+
"""
68+
return {}
4869

4970
def __eq__(self, other: object) -> bool:
5071
"""Equality comparison."""
51-
return isinstance(other, type(self))
72+
if not isinstance(other, type(self)):
73+
return False
74+
return self.serialize() == other.serialize()
5275

53-
def __hash__(self) -> int:
54-
return hash(None)
76+
def __hash__(self):
77+
return hash(self.serialize())
5578

5679

5780
class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
5881
"""Priority weight strategy that uses the task's priority weight directly."""
5982

60-
def get_weight(self, ti: TaskInstance) -> int:
83+
def get_weight(self, ti: TaskInstance):
6184
if TYPE_CHECKING:
6285
assert ti.task
6386
return ti.task.priority_weight
@@ -81,7 +104,7 @@ def get_weight(self, ti: TaskInstance) -> int:
81104
class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
82105
"""Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""
83106

84-
def get_weight(self, ti: TaskInstance) -> int:
107+
def get_weight(self, ti: TaskInstance):
85108
if TYPE_CHECKING:
86109
assert ti.task
87110
dag = ti.task.get_dag()
@@ -93,19 +116,14 @@ def get_weight(self, ti: TaskInstance) -> int:
93116

94117

95118
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
96-
qualname(_AbsolutePriorityWeightStrategy): _AbsolutePriorityWeightStrategy,
97-
qualname(_DownstreamPriorityWeightStrategy): _DownstreamPriorityWeightStrategy,
98-
qualname(_UpstreamPriorityWeightStrategy): _UpstreamPriorityWeightStrategy,
99119
WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
100120
WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
101121
WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
102122
}
103123

104124

105125
airflow_priority_weight_strategies_classes = {
106-
_AbsolutePriorityWeightStrategy: WeightRule.ABSOLUTE,
107-
_DownstreamPriorityWeightStrategy: WeightRule.DOWNSTREAM,
108-
_UpstreamPriorityWeightStrategy: WeightRule.UPSTREAM,
126+
cls: name for name, cls in airflow_priority_weight_strategies.items()
109127
}
110128

111129

@@ -121,6 +139,7 @@ def validate_and_load_priority_weight_strategy(
121139
122140
:meta private:
123141
"""
142+
from airflow._shared.module_loading import qualname
124143
from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
125144

126145
if priority_weight_strategy is None:

airflow-core/tests/unit/jobs/test_triggerer_job.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,12 @@ def create_trigger_in_db(session, trigger, operator=None):
130130
operator = BaseOperator(task_id="test_ti", dag=dag)
131131
session.add(dag_model)
132132

133-
lazy_serdag = LazyDeserializedDAG.from_dag(dag)
134-
SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
133+
SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
135134
session.add(run)
136135
session.add(trigger_orm)
137136
session.flush()
138137
dag_version = DagVersion.get_latest_version(dag.dag_id)
139-
task_instance = TaskInstance(
140-
lazy_serdag._real_dag.get_task(operator.task_id),
141-
run_id=run.run_id,
142-
dag_version_id=dag_version.id,
143-
)
138+
task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id)
144139
task_instance.trigger_id = trigger_orm.id
145140
session.add(task_instance)
146141
session.commit()
@@ -445,8 +440,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):
445440

446441

447442
@pytest.mark.asyncio
448-
@pytest.mark.usefixtures("testing_dag_bundle")
449-
async def test_trigger_create_race_condition_38599(session, supervisor_builder):
443+
async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
450444
"""
451445
This verifies the resolution of race condition documented in github issue #38599.
452446
More details in the issue description.
@@ -471,17 +465,14 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder):
471465
session.flush()
472466

473467
bundle_name = "testing"
474-
with DAG(dag_id="test-dag") as dag:
475-
task = PythonOperator(task_id="dummy-task", python_callable=print)
468+
dag = DAG(dag_id="test-dag")
476469
dm = DagModel(dag_id="test-dag", bundle_name=bundle_name)
477470
session.add(dm)
478-
479-
lazy_serdag = LazyDeserializedDAG.from_dag(dag)
480-
SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
471+
SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
481472
dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow())
482473
dag_version = DagVersion.get_latest_version(dag.dag_id)
483474
ti = TaskInstance(
484-
lazy_serdag._real_dag.get_task(task.task_id),
475+
PythonOperator(task_id="dummy-task", python_callable=print),
485476
run_id=dag_run.run_id,
486477
state=TaskInstanceState.DEFERRED,
487478
dag_version_id=dag_version.id,

airflow-core/tests/unit/models/test_mappedoperator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from airflow.providers.standard.operators.python import PythonOperator
3535
from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown
3636
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
37+
from airflow.task.priority_strategy import PriorityWeightStrategy
3738
from airflow.task.trigger_rule import TriggerRule
3839
from airflow.utils.state import TaskInstanceState
3940

@@ -1523,7 +1524,7 @@ def test_properties(
15231524
assert op.pool == SerializedBaseOperator.pool
15241525
assert op.pool_slots == SerializedBaseOperator.pool_slots
15251526
assert op.priority_weight == SerializedBaseOperator.priority_weight
1526-
assert op.weight_rule == "downstream"
1527+
assert isinstance(op.weight_rule, PriorityWeightStrategy)
15271528
assert op.email == email
15281529
assert op.execution_timeout == execution_timeout
15291530
assert op.retry_delay == retry_delay

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
from airflow.serialization.definitions.assets import SerializedAsset
8383
from airflow.serialization.definitions.dag import SerializedDAG
8484
from airflow.serialization.encoders import ensure_serialized_asset
85-
from airflow.serialization.serialized_objects import OperatorSerialization, create_scheduler_operator
85+
from airflow.serialization.serialized_objects import OperatorSerialization
8686
from airflow.ti_deps.dep_context import DepContext
8787
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
8888
from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -2812,7 +2812,7 @@ def mock_policy(task_instance: TaskInstance):
28122812

28132813
monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
28142814

2815-
sdk_task = EmptyOperator(
2815+
task = EmptyOperator(
28162816
task_id="empty",
28172817
queue=default_queue,
28182818
pool="test_pool1",
@@ -2822,28 +2822,27 @@ def mock_policy(task_instance: TaskInstance):
28222822
retries=30,
28232823
executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
28242824
)
2825-
ser_task = create_scheduler_operator(sdk_task)
2826-
ti = TI(ser_task, run_id=None, dag_version_id=mock.MagicMock())
2827-
ti.refresh_from_task(ser_task, pool_override=pool_override)
2825+
ti = TI(task, run_id=None, dag_version_id=mock.MagicMock())
2826+
ti.refresh_from_task(task, pool_override=pool_override)
28282827

28292828
assert ti.queue == expected_queue
28302829

28312830
if pool_override:
28322831
assert ti.pool == pool_override
28332832
else:
2834-
assert ti.pool == sdk_task.pool
2833+
assert ti.pool == task.pool
28352834

2836-
assert ti.pool_slots == sdk_task.pool_slots
2837-
assert ti.priority_weight == ser_task.weight_rule.get_weight(ti)
2838-
assert ti.run_as_user == sdk_task.run_as_user
2839-
assert ti.max_tries == sdk_task.retries
2840-
assert ti.executor_config == sdk_task.executor_config
2835+
assert ti.pool_slots == task.pool_slots
2836+
assert ti.priority_weight == task.weight_rule.get_weight(ti)
2837+
assert ti.run_as_user == task.run_as_user
2838+
assert ti.max_tries == task.retries
2839+
assert ti.executor_config == task.executor_config
28412840
assert ti.operator == EmptyOperator.__name__
28422841

28432842
# Test that refresh_from_task does not reset ti.max_tries
2844-
expected_max_tries = sdk_task.retries + 10
2843+
expected_max_tries = task.retries + 10
28452844
ti.max_tries = expected_max_tries
2846-
ti.refresh_from_task(ser_task)
2845+
ti.refresh_from_task(task)
28472846
assert ti.max_tries == expected_max_tries
28482847

28492848

airflow-core/tests/unit/serialization/test_dag_serialization.py

Lines changed: 23 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,7 @@
8484
OperatorSerialization,
8585
_XComRef,
8686
)
87-
from airflow.task.priority_strategy import (
88-
PriorityWeightStrategy,
89-
_DownstreamPriorityWeightStrategy,
90-
airflow_priority_weight_strategies,
91-
validate_and_load_priority_weight_strategy,
92-
)
87+
from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
9388
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
9489
from airflow.timetables.simple import NullTimetable, OnceTimetable
9590
from airflow.triggers.base import StartTriggerArgs
@@ -784,7 +779,6 @@ def validate_deserialized_task(
784779
"inlets",
785780
"outlets",
786781
"task_type",
787-
"weight_rule",
788782
"_operator_name",
789783
# Type is excluded, so don't check it
790784
"_log",
@@ -818,7 +812,6 @@ def validate_deserialized_task(
818812
"operator_class",
819813
"partial_kwargs",
820814
"expand_input",
821-
"weight_rule",
822815
}
823816

824817
assert serialized_task.task_type == task.task_type
@@ -853,12 +846,6 @@ def validate_deserialized_task(
853846
if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
854847
assert serialized_task.params.dump() == task.params.dump()
855848

856-
if isinstance(task.weight_rule, PriorityWeightStrategy):
857-
assert task.weight_rule == serialized_task.weight_rule
858-
else:
859-
task_weight_strat = validate_and_load_priority_weight_strategy(task.weight_rule)
860-
assert task_weight_strat == serialized_task.weight_rule
861-
862849
if isinstance(task, MappedOperator):
863850
# MappedOperator.operator_class now stores only minimal type information
864851
# for memory efficiency (task_type and _operator_name).
@@ -1585,7 +1572,7 @@ def test_no_new_fields_added_to_base_operator(self):
15851572
"ui_fgcolor": "#000",
15861573
"wait_for_downstream": False,
15871574
"wait_for_past_depends_before_skipping": False,
1588-
"weight_rule": WeightRule.DOWNSTREAM,
1575+
"weight_rule": _DownstreamPriorityWeightStrategy(),
15891576
"multiple_outputs": False,
15901577
}, """
15911578
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -3999,6 +3986,27 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
39993986
assert getattr(deserialized_task_empty, new_callback_name) is False
40003987

40013988

3989+
def test_weight_rule_absolute_serialization_deserialization():
3990+
"""Test that weight_rule can be serialized and deserialized correctly."""
3991+
from airflow.sdk import task
3992+
3993+
with DAG("test_weight_rule_dag") as dag:
3994+
3995+
@task(weight_rule=WeightRule.ABSOLUTE)
3996+
def test_task():
3997+
return "test"
3998+
3999+
test_task()
4000+
4001+
serialized_dag = DagSerialization.to_dict(dag)
4002+
assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == "absolute"
4003+
4004+
deserialized_dag = DagSerialization.from_dict(serialized_dag)
4005+
4006+
deserialized_task = deserialized_dag.task_dict["test_task"]
4007+
assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy)
4008+
4009+
40024010
class TestClientDefaultsGeneration:
40034011
"""Test client defaults generation functionality."""
40044012

@@ -4471,50 +4479,3 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags,
44714479

44724480
deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict)
44734481
assert deserialized_dag.dag_id == "test_default_args_callbacks"
4474-
4475-
4476-
class RegisteredPriorityWeightStrategy(PriorityWeightStrategy):
4477-
def get_weight(self, ti):
4478-
return 99
4479-
4480-
4481-
class TestWeightRule:
4482-
def test_default(self):
4483-
sdkop = BaseOperator(task_id="should_fail")
4484-
serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
4485-
assert serop.weight_rule == _DownstreamPriorityWeightStrategy()
4486-
4487-
@pytest.mark.parametrize(("value", "expected"), list(airflow_priority_weight_strategies.items()))
4488-
def test_builtin(self, value, expected):
4489-
sdkop = BaseOperator(task_id="should_fail", weight_rule=value)
4490-
serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
4491-
assert serop.weight_rule == expected()
4492-
4493-
def test_custom(self):
4494-
sdkop = BaseOperator(task_id="should_fail", weight_rule=RegisteredPriorityWeightStrategy())
4495-
with mock.patch(
4496-
"airflow.serialization.serialized_objects._get_registered_priority_weight_strategy",
4497-
return_value=RegisteredPriorityWeightStrategy,
4498-
) as mock_get_registered_priority_weight_strategy:
4499-
serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
4500-
4501-
assert serop.weight_rule == RegisteredPriorityWeightStrategy()
4502-
assert mock_get_registered_priority_weight_strategy.mock_calls == [
4503-
mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
4504-
mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
4505-
mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
4506-
]
4507-
4508-
def test_invalid(self):
4509-
op = BaseOperator(task_id="should_fail", weight_rule="no rule")
4510-
with pytest.raises(ValueError, match="Unknown priority strategy"):
4511-
OperatorSerialization.serialize(op)
4512-
4513-
def test_not_registered_custom(self):
4514-
class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
4515-
def get_weight(self, ti):
4516-
return 99
4517-
4518-
op = BaseOperator(task_id="empty_task", weight_rule=NotRegisteredPriorityWeightStrategy())
4519-
with pytest.raises(ValueError, match="Unknown priority strategy"):
4520-
OperatorSerialization.serialize(op)

0 commit comments

Comments
 (0)