2828from airflow .utils .state import TaskInstanceState
2929
3030from tests_common .test_utils import db
31+ from tests_common .test_utils .config import conf_vars
3132
3233pytestmark = pytest .mark .db_test
3334
3435
3536class TestPoolSlotsAvailableDep :
3637 def setup_method (self ):
3738 db .clear_db_pools ()
39+ db .clear_db_teams ()
3840 with create_session () as session :
3941 test_pool = Pool (pool = "test_pool" , include_deferred = False )
4042 test_includes_deferred_pool = Pool (pool = "test_includes_deferred_pool" , include_deferred = True )
@@ -43,6 +45,7 @@ def setup_method(self):
4345
4446 def teardown_method (self ):
4547 db .clear_db_pools ()
48+ db .clear_db_teams ()
4649
4750 @patch ("airflow.models.Pool.open_slots" , return_value = 0 )
4851 def test_pooled_task_reached_concurrency (self , mock_open_slots ):
@@ -70,3 +73,68 @@ def test_deferred_pooled_task_pass(self, mock_open_slots):
7073 def test_task_with_nonexistent_pool (self ):
7174 ti = Mock (pool = "nonexistent_pool" , pool_slots = 1 )
7275 assert not PoolSlotsAvailableDep ().is_met (ti = ti )
76+
77+ @conf_vars ({("core" , "multi_team" ): "True" })
78+ def test_pool_team_mismatch (self ):
79+ """Test that a task from one team cannot use a pool assigned to another team."""
80+ from airflow .models .dag import DagModel
81+ from airflow .models .team import Team
82+
83+ with create_session () as session :
84+ # Create teams
85+ team_a = Team (name = "teamA" )
86+ team_b = Team (name = "teamB" )
87+ session .add_all ([team_a , team_b ])
88+ session .commit ()
89+
90+ # Create a pool assigned to teamA
91+ pool_team_a = Pool (pool = "pool_team_a" , slots = 10 , include_deferred = False , team_name = "teamA" )
92+ session .add (pool_team_a )
93+ session .commit ()
94+
95+ # Mock a task instance from a DAG belonging to teamB
96+ ti = Mock (pool = "pool_team_a" , pool_slots = 1 , dag_id = "test_dag" )
97+
98+ with patch .object (DagModel , "get_team_name" , return_value = "teamB" ):
99+ assert not PoolSlotsAvailableDep ().is_met (ti = ti )
100+
101+ @conf_vars ({("core" , "multi_team" ): "True" })
102+ def test_pool_team_match (self ):
103+ """Test that a task from a team can use a pool assigned to the same team."""
104+ from airflow .models .dag import DagModel
105+ from airflow .models .team import Team
106+
107+ with create_session () as session :
108+ # Create team
109+ team_a = Team (name = "teamA" )
110+ session .add (team_a )
111+ session .commit ()
112+
113+ # Create a pool assigned to teamA
114+ pool_team_a = Pool (pool = "pool_team_a" , slots = 10 , include_deferred = False , team_name = "teamA" )
115+ session .add (pool_team_a )
116+ session .commit ()
117+
118+ # Mock a task instance from a DAG belonging to teamA
119+ ti = Mock (pool = "pool_team_a" , pool_slots = 1 , dag_id = "test_dag" , state = None )
120+
121+ with patch .object (DagModel , "get_team_name" , return_value = "teamA" ):
122+ with patch ("airflow.models.Pool.open_slots" , return_value = 5 ):
123+ assert PoolSlotsAvailableDep ().is_met (ti = ti )
124+
125+ def test_pool_no_team_assignment (self ):
126+ """Test that a pool without team assignment can be used by any DAG."""
127+ from airflow .models .dag import DagModel
128+
129+ with create_session () as session :
130+ # Create a pool without team assignment
131+ pool_no_team = Pool (pool = "pool_no_team" , slots = 10 , include_deferred = False , team_name = None )
132+ session .add (pool_no_team )
133+ session .commit ()
134+
135+ # Mock a task instance from any DAG
136+ ti = Mock (pool = "pool_no_team" , pool_slots = 1 , dag_id = "test_dag" , state = None )
137+
138+ with patch .object (DagModel , "get_team_name" , return_value = "teamA" ):
139+ with patch ("airflow.models.Pool.open_slots" , return_value = 5 ):
140+ assert PoolSlotsAvailableDep ().is_met (ti = ti )
0 commit comments