44import  pickle 
55import  queue 
66import  subprocess 
7+ from  dataclasses  import  dataclass , field 
78from  enum  import  Enum 
89from  typing  import  Dict , List , Optional 
910
@@ -69,6 +70,14 @@ def __str__(self):
6970        return  f"MPINodesUnavailable(requested={ self .requested }   available={ self .available }  )" 
7071
7172
73+ @dataclass (order = True ) 
74+ class  PrioritizedTask :
75+     # Comparing dict will fail since they are unhashable 
76+     # This dataclass limits comparison to the priority field 
77+     priority : int 
78+     task : Dict  =  field (compare = False )
79+ 
80+ 
7281class  TaskScheduler :
7382    """Default TaskScheduler that does no taskscheduling 
7483
@@ -111,7 +120,7 @@ def __init__(
111120        super ().__init__ (pending_task_q , pending_result_q )
112121        self .scheduler  =  identify_scheduler ()
113122        # PriorityQueue is threadsafe 
114-         self ._backlog_queue : queue .PriorityQueue  =  queue .PriorityQueue ()
123+         self ._backlog_queue : queue .PriorityQueue [ PrioritizedTask ]  =  queue .PriorityQueue ()
115124        self ._map_tasks_to_nodes : Dict [str , List [str ]] =  {}
116125        self .available_nodes  =  get_nodes_in_batchjob (self .scheduler )
117126        self ._free_node_counter  =  SpawnContext .Value ("i" , len (self .available_nodes ))
@@ -169,7 +178,8 @@ def put_task(self, task_package: dict):
169178                allocated_nodes  =  self ._get_nodes (nodes_needed )
170179            except  MPINodesUnavailable :
171180                logger .info (f"Not enough resources, placing task { tid }   into backlog" )
172-                 self ._backlog_queue .put ((nodes_needed , task_package ))
181+                 # Negate the priority element so that larger tasks are prioritized 
182+                 self ._backlog_queue .put (PrioritizedTask (- 1  *  nodes_needed , task_package ))
173183                return 
174184            else :
175185                resource_spec ["MPI_NODELIST" ] =  "," .join (allocated_nodes )
@@ -182,14 +192,16 @@ def put_task(self, task_package: dict):
182192
183193    def  _schedule_backlog_tasks (self ):
184194        """Attempt to schedule backlogged tasks""" 
185-         try :
186-             _nodes_requested , task_package  =  self ._backlog_queue .get (block = False )
187-             self .put_task (task_package )
188-         except  queue .Empty :
189-             return 
190-         else :
191-             # Keep attempting to schedule tasks till we are out of resources 
192-             self ._schedule_backlog_tasks ()
195+ 
196+         # Separate fetching tasks from the _backlog_queue and scheduling them 
197+         # since tasks that failed to schedule will be pushed to the _backlog_queue 
198+         backlogged_tasks  =  []
199+         while  not  self ._backlog_queue .empty ():
200+             prioritized_task  =  self ._backlog_queue .get (block = False )
201+             backlogged_tasks .append (prioritized_task .task )
202+ 
203+         for  backlogged_task  in  backlogged_tasks :
204+             self .put_task (backlogged_task )
193205
194206    def  get_result (self , block : bool  =  True , timeout : Optional [float ] =  None ):
195207        """Return result and relinquish provisioned nodes""" 
0 commit comments