Skip to content

Commit 0b1e14d

Browse files
authored
Multi-team. Verify a task uses a pool it has access to when scheduling (#61227)
1 parent d65ff01 commit 0b1e14d

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

airflow-core/src/airflow/models/pool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def to_json(self) -> dict[str, Any]:
254254
"slots": self.slots,
255255
"description": self.description,
256256
"include_deferred": self.include_deferred,
257+
"team": self.team_name,
257258
}
258259

259260
@provide_session

airflow-core/src/airflow/ti_deps/deps/pool_slots_available_dep.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from sqlalchemy import select
2323

24+
from airflow.configuration import conf
2425
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
2526
from airflow.utils.session import provide_session
2627

@@ -41,6 +42,7 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
4142
:param dep_context: the context for which this dependency should be evaluated for
4243
:return: True if there are available slots in the pool.
4344
"""
45+
from airflow.models.dag import DagModel
4446
from airflow.models.pool import Pool # To avoid a circular dependency
4547

4648
pool_name = ti.pool
@@ -53,6 +55,21 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
5355
)
5456
return
5557

58+
# Check team compatibility
59+
if pool.team_name:
60+
if not conf.getboolean("core", "multi_team"):
61+
raise ValueError(
62+
"Multi-team mode is not configured in the Airflow environment but the pool the task is using belongs to a team"
63+
)
64+
65+
dag_team_name = DagModel.get_team_name(ti.dag_id, session=session)
66+
if dag_team_name != pool.team_name:
67+
yield self._failing_status(
68+
reason=f"Pool '{pool_name}' is assigned to team '{pool.team_name}' "
69+
f"but task's DAG belongs to team '{dag_team_name or 'None'}'"
70+
)
71+
return
72+
5673
open_slots = pool.open_slots(session=session)
5774
if ti.state in pool.get_occupied_states():
5875
open_slots += ti.pool_slots

airflow-core/tests/unit/ti_deps/deps/test_pool_slots_available_dep.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
from airflow.utils.state import TaskInstanceState
2929

3030
from tests_common.test_utils import db
31+
from tests_common.test_utils.config import conf_vars
3132

3233
pytestmark = pytest.mark.db_test
3334

3435

3536
class TestPoolSlotsAvailableDep:
3637
def setup_method(self):
3738
db.clear_db_pools()
39+
db.clear_db_teams()
3840
with create_session() as session:
3941
test_pool = Pool(pool="test_pool", include_deferred=False)
4042
test_includes_deferred_pool = Pool(pool="test_includes_deferred_pool", include_deferred=True)
@@ -43,6 +45,7 @@ def setup_method(self):
4345

4446
def teardown_method(self):
4547
db.clear_db_pools()
48+
db.clear_db_teams()
4649

4750
@patch("airflow.models.Pool.open_slots", return_value=0)
4851
def test_pooled_task_reached_concurrency(self, mock_open_slots):
@@ -70,3 +73,68 @@ def test_deferred_pooled_task_pass(self, mock_open_slots):
7073
def test_task_with_nonexistent_pool(self):
7174
ti = Mock(pool="nonexistent_pool", pool_slots=1)
7275
assert not PoolSlotsAvailableDep().is_met(ti=ti)
76+
77+
@conf_vars({("core", "multi_team"): "True"})
78+
def test_pool_team_mismatch(self):
79+
"""Test that a task from one team cannot use a pool assigned to another team."""
80+
from airflow.models.dag import DagModel
81+
from airflow.models.team import Team
82+
83+
with create_session() as session:
84+
# Create teams
85+
team_a = Team(name="teamA")
86+
team_b = Team(name="teamB")
87+
session.add_all([team_a, team_b])
88+
session.commit()
89+
90+
# Create a pool assigned to teamA
91+
pool_team_a = Pool(pool="pool_team_a", slots=10, include_deferred=False, team_name="teamA")
92+
session.add(pool_team_a)
93+
session.commit()
94+
95+
# Mock a task instance from a DAG belonging to teamB
96+
ti = Mock(pool="pool_team_a", pool_slots=1, dag_id="test_dag")
97+
98+
with patch.object(DagModel, "get_team_name", return_value="teamB"):
99+
assert not PoolSlotsAvailableDep().is_met(ti=ti)
100+
101+
@conf_vars({("core", "multi_team"): "True"})
102+
def test_pool_team_match(self):
103+
"""Test that a task from a team can use a pool assigned to the same team."""
104+
from airflow.models.dag import DagModel
105+
from airflow.models.team import Team
106+
107+
with create_session() as session:
108+
# Create team
109+
team_a = Team(name="teamA")
110+
session.add(team_a)
111+
session.commit()
112+
113+
# Create a pool assigned to teamA
114+
pool_team_a = Pool(pool="pool_team_a", slots=10, include_deferred=False, team_name="teamA")
115+
session.add(pool_team_a)
116+
session.commit()
117+
118+
# Mock a task instance from a DAG belonging to teamA
119+
ti = Mock(pool="pool_team_a", pool_slots=1, dag_id="test_dag", state=None)
120+
121+
with patch.object(DagModel, "get_team_name", return_value="teamA"):
122+
with patch("airflow.models.Pool.open_slots", return_value=5):
123+
assert PoolSlotsAvailableDep().is_met(ti=ti)
124+
125+
def test_pool_no_team_assignment(self):
126+
"""Test that a pool without team assignment can be used by any DAG."""
127+
from airflow.models.dag import DagModel
128+
129+
with create_session() as session:
130+
# Create a pool without team assignment
131+
pool_no_team = Pool(pool="pool_no_team", slots=10, include_deferred=False, team_name=None)
132+
session.add(pool_no_team)
133+
session.commit()
134+
135+
# Mock a task instance from any DAG
136+
ti = Mock(pool="pool_no_team", pool_slots=1, dag_id="test_dag", state=None)
137+
138+
with patch.object(DagModel, "get_team_name", return_value="teamA"):
139+
with patch("airflow.models.Pool.open_slots", return_value=5):
140+
assert PoolSlotsAvailableDep().is_met(ti=ti)

0 commit comments

Comments
 (0)