44import pickle
55import queue
66import subprocess
7+ from dataclasses import dataclass , field
78from enum import Enum
89from typing import Dict , List , Optional
910
@@ -69,6 +70,14 @@ def __str__(self):
6970 return f"MPINodesUnavailable(requested={ self .requested } available={ self .available } )"
7071
7172
73+ @dataclass (order = True )
74+ class PrioritizedTask :
75+ # Comparing dict will fail since they are unhashable
76+ # This dataclass limits comparison to the priority field
77+ priority : int
78+ task : Dict = field (compare = False )
79+
80+
7281class TaskScheduler :
7382 """Default TaskScheduler that does no taskscheduling
7483
@@ -111,7 +120,7 @@ def __init__(
111120 super ().__init__ (pending_task_q , pending_result_q )
112121 self .scheduler = identify_scheduler ()
113122 # PriorityQueue is threadsafe
114- self ._backlog_queue : queue .PriorityQueue = queue .PriorityQueue ()
123+ self ._backlog_queue : queue .PriorityQueue [ PrioritizedTask ] = queue .PriorityQueue ()
115124 self ._map_tasks_to_nodes : Dict [str , List [str ]] = {}
116125 self .available_nodes = get_nodes_in_batchjob (self .scheduler )
117126 self ._free_node_counter = SpawnContext .Value ("i" , len (self .available_nodes ))
@@ -169,7 +178,7 @@ def put_task(self, task_package: dict):
169178 allocated_nodes = self ._get_nodes (nodes_needed )
170179 except MPINodesUnavailable :
171180 logger .info (f"Not enough resources, placing task { tid } into backlog" )
172- self ._backlog_queue .put ((nodes_needed , task_package ))
181+ self ._backlog_queue .put (PrioritizedTask (nodes_needed , task_package ))
173182 return
174183 else :
175184 resource_spec ["MPI_NODELIST" ] = "," .join (allocated_nodes )
@@ -183,8 +192,8 @@ def put_task(self, task_package: dict):
183192 def _schedule_backlog_tasks (self ):
184193 """Attempt to schedule backlogged tasks"""
185194 try :
186- _nodes_requested , task_package = self ._backlog_queue .get (block = False )
187- self .put_task (task_package )
195+ prioritized_task = self ._backlog_queue .get (block = False )
196+ self .put_task (prioritized_task . task )
188197 except queue .Empty :
189198 return
190199 else :
0 commit comments