@@ -94,11 +94,11 @@ def reload_state(self) -> None:
         self._prev_nodes_by_ip = copy.deepcopy(self._nodes_by_ip)
         self._nodes_by_ip = self._get_nodes_by_ip()
         logger.info("Reloading pods")
-        (
-            self._pods_by_ip,
-            self._unschedulable_pods,
-            self._excluded_pods_by_ip,
-        ) = self._get_pods_info()
+        (self._pods_by_ip, self._unschedulable_pods, self._excluded_pods_by_ip,) = (
+            self._get_pods_info_with_label()
+            if self.pool_config.read_bool("use_labels_for_pods", default=False)
+            else self._get_pods_info()
+        )

     def reload_client(self) -> None:
         self._core_api = CachedCoreV1Api(self.kubeconfig_path)
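Note: the tuple unpacking above now dispatches on a pool-level feature flag, so the label-based pod listing can be rolled out gradually while the old full-scan path stays the default. A minimal sketch of the same read_bool-gated dispatch, assuming a staticconf-style reader; the PoolConfig stub below is a hypothetical stand-in, not clusterman's real config plumbing:

# Minimal sketch, assuming a staticconf-style reader; PoolConfig is a
# hypothetical stand-in for clusterman's real pool config object.
from typing import Any, Callable, Dict


class PoolConfig:
    def __init__(self, values: Dict[str, Any]) -> None:
        self._values = values

    def read_bool(self, key: str, default: bool = False) -> bool:
        return bool(self._values.get(key, default))


def load_pods(pool_config: PoolConfig, with_label: Callable, full_scan: Callable):
    # Same conditional-expression dispatch as in the diff: fall back to the
    # old full-scan path unless the flag is explicitly enabled.
    return with_label() if pool_config.read_bool("use_labels_for_pods", default=False) else full_scan()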
@@ -403,14 +403,13 @@ def _is_node_safe_to_kill(self, node_ip: str) -> bool:
         return True

     def _get_nodes_by_ip(self) -> Mapping[str, KubernetesNode]:
-        pool_label_selector = self.pool_config.read_string("pool_label_key", default="clusterman.com/pool")
-        pool_nodes = self._core_api.list_node().items
-
-        return {
-            get_node_ip(node): node
-            for node in pool_nodes
-            if not self.pool or node.metadata.labels.get(pool_label_selector, None) == self.pool
-        }
+        # TODO(CLUSTERMAN-659): Switch to using just pool_label_key once the new node labels are applied everywhere
+        node_label_selector = self.pool_config.read_string(
+            "node_label_key", default=self.pool_config.read_string("pool_label_key", default="clusterman.com/pool")
+        )
+        label_selector = f"{node_label_selector}={self.pool}"
+        pool_nodes = self._core_api.list_node(label_selector=label_selector).items
+        return {get_node_ip(node): node for node in pool_nodes}

     def _get_pods_info(
         self,
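The node lookup above changes from client-side filtering (list every node, compare metadata.labels) to a server-side label_selector, so the API server only returns the pool's nodes. A hedged sketch of the equivalent call with the official kubernetes Python client; the pool name "my-pool" is illustrative, and the InternalIP lookup stands in for clusterman's get_node_ip helper:

# Sketch using the official kubernetes client; "my-pool" is illustrative and
# the InternalIP lookup is a stand-in for clusterman's get_node_ip helper.
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() when run in-cluster
v1 = client.CoreV1Api()

# Server-side filtering: only nodes labeled clusterman.com/pool=my-pool are returned.
nodes = v1.list_node(label_selector="clusterman.com/pool=my-pool").items
nodes_by_ip = {
    addr.address: node
    for node in nodes
    for addr in (node.status.addresses or [])
    if addr.type == "InternalIP"
}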
@@ -436,6 +435,30 @@ def _get_pods_info(
             logger.info(f"Skipping {pod.metadata.name} pod ({pod.status.phase})")
         return pods_by_ip, unschedulable_pods, excluded_pods_by_ip

+    def _get_pods_info_with_label(
+        self,
+    ) -> Tuple[Mapping[str, List[KubernetesPod]], List[KubernetesPod], Mapping[str, List[KubernetesPod]],]:
+        pods_by_ip: Mapping[str, List[KubernetesPod]] = defaultdict(list)
+        unschedulable_pods: List[KubernetesPod] = []
+        excluded_pods_by_ip: Mapping[str, List[KubernetesPod]] = defaultdict(list)
+
+        exclude_daemonset_pods = self.pool_config.read_bool(
+            "exclude_daemonset_pods",
+            default=staticconf.read_bool("exclude_daemonset_pods", default=False),
+        )
+        label_selector = f"{self.pool_label_key}={self.pool}"
+
+        for pod in self._core_api.list_pod_for_all_namespaces(label_selector=label_selector).items:
+            if exclude_daemonset_pods and self._pod_belongs_to_daemonset(pod):
+                excluded_pods_by_ip[pod.status.host_ip].append(pod)
+            elif pod.status.phase == "Running" or self._is_recently_scheduled(pod):
+                pods_by_ip[pod.status.host_ip].append(pod)
+            elif self._is_unschedulable(pod):
+                unschedulable_pods.append(pod)
+            else:
+                logger.info(f"Skipping {pod.metadata.name} pod ({pod.status.phase})")
+        return pods_by_ip, unschedulable_pods, excluded_pods_by_ip
+
     def _count_batch_tasks(self, node_ip: str) -> int:
         count = 0
         for pod in self._pods_by_ip[node_ip]:
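_get_pods_info_with_label mirrors _get_pods_info but lets the API server do the pool filtering via list_pod_for_all_namespaces(label_selector=...) before bucketing pods into running, unschedulable, and excluded groups. A simplified sketch of that bucketing; the DaemonSet and unschedulability checks below are stand-ins for the connector's _pod_belongs_to_daemonset, _is_recently_scheduled, and _is_unschedulable helpers, not clusterman's actual logic:

# Sketch: the predicates here are simplified stand-ins for the connector's
# own helpers, not clusterman's actual logic.
from collections import defaultdict
from typing import Dict, List, Tuple

from kubernetes import client
from kubernetes.client import V1Pod


def bucket_pods(
    v1: client.CoreV1Api, label_selector: str
) -> Tuple[Dict[str, List[V1Pod]], List[V1Pod], Dict[str, List[V1Pod]]]:
    pods_by_ip: Dict[str, List[V1Pod]] = defaultdict(list)
    unschedulable: List[V1Pod] = []
    excluded_by_ip: Dict[str, List[V1Pod]] = defaultdict(list)

    for pod in v1.list_pod_for_all_namespaces(label_selector=label_selector).items:
        owners = pod.metadata.owner_references or []
        if any(ref.kind == "DaemonSet" for ref in owners):
            # Stand-in for _pod_belongs_to_daemonset: owned by a DaemonSet.
            excluded_by_ip[pod.status.host_ip].append(pod)
        elif pod.status.phase == "Running":
            pods_by_ip[pod.status.host_ip].append(pod)
        elif any(
            c.type == "PodScheduled" and c.reason == "Unschedulable"
            for c in (pod.status.conditions or [])
        ):
            # Stand-in for _is_unschedulable: an Unschedulable PodScheduled condition.
            unschedulable.append(pod)
    return pods_by_ip, unschedulable, excluded_by_ip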