66import  subprocess 
77from  dataclasses  import  dataclass , field 
88from  enum  import  Enum 
9- from  typing  import  Dict , List , Optional 
9+ from  typing  import  Dict , List , Optional ,  Tuple 
1010
1111from  parsl .multiprocessing  import  SpawnContext 
1212from  parsl .serialize  import  pack_res_spec_apply_message , unpack_res_spec_apply_message 
@@ -76,6 +76,8 @@ class PrioritizedTask:
7676    # This dataclass limits comparison to the priority field 
7777    priority : int 
7878    task : Dict  =  field (compare = False )
79+     unpacked_task : Tuple  =  field (compare = False )
80+     nodes_needed : int  =  field (compare = False )
7981
8082
8183class  TaskScheduler :
@@ -165,29 +167,41 @@ def _return_nodes(self, nodes: List[str]) -> None:
165167        with  self ._free_node_counter .get_lock ():
166168            self ._free_node_counter .value  +=  len (nodes )  # type: ignore[attr-defined] 
167169
168-     def  put_task (self , task_package : dict ):
169-         """Schedule task if resources are available otherwise backlog the task""" 
170-         user_ns  =  locals ()
171-         user_ns .update ({"__builtins__" : __builtins__ })
172-         _f , _args , _kwargs , resource_spec  =  unpack_res_spec_apply_message (task_package ["buffer" ])
173- 
174-         nodes_needed  =  resource_spec .get ("num_nodes" )
175-         tid  =  task_package ["task_id" ]
170+     def  schedule_task (self , p_task : PrioritizedTask ):
171+         """Schedule a prioritized task if resources are available, and push to backlog 
172+         otherwise.""" 
173+         nodes_needed  =  p_task .nodes_needed 
174+         tid  =  p_task .task ["task_id" ]
176175        if  nodes_needed :
177176            try :
178177                allocated_nodes  =  self ._get_nodes (nodes_needed )
179178            except  MPINodesUnavailable :
180179                logger .info (f"Not enough resources, placing task { tid }   into backlog" )
181-                 self ._backlog_queue .put (PrioritizedTask ( nodes_needed ,  task_package ) )
180+                 self ._backlog_queue .put (p_task )
182181                return 
183182            else :
183+                 f , args , kwargs , resource_spec  =  p_task .unpacked_task 
184184                resource_spec ["MPI_NODELIST" ] =  "," .join (allocated_nodes )
185185                self ._map_tasks_to_nodes [tid ] =  allocated_nodes 
186-                 buffer  =  pack_res_spec_apply_message (_f , _args , _kwargs , resource_spec )
187-                 task_package ["buffer" ] =  buffer 
188-                 task_package ["resource_spec" ] =  resource_spec 
186+                 buffer  =  pack_res_spec_apply_message (f , args , kwargs , resource_spec )
187+                 p_task .task ["buffer" ] =  buffer 
188+                 p_task .task ["resource_spec" ] =  resource_spec 
189+ 
190+         self .pending_task_q .put (p_task .task )
191+ 
192+     def  put_task (self , task_package : dict ):
193+         """Schedule task if resources are available otherwise backlog the task""" 
194+         user_ns  =  locals ()
195+         user_ns .update ({"__builtins__" : __builtins__ })
196+         _f , _args , _kwargs , resource_spec  =  unpack_res_spec_apply_message (task_package ["buffer" ])
197+ 
198+         nodes_needed  =  resource_spec .get ("num_nodes" )
199+         prioritized_task  =  PrioritizedTask (priority = nodes_needed ,
200+                                            task = task_package ,
201+                                            unpacked_task = (_f , _args , _kwargs , resource_spec ),
202+                                            nodes_needed = nodes_needed )
189203
190-         self .pending_task_q . put ( task_package )
204+         self .schedule_task ( prioritized_task )
191205
192206    def  _schedule_backlog_tasks (self ):
193207        """Attempt to schedule backlogged tasks""" 
@@ -198,12 +212,12 @@ def _schedule_backlog_tasks(self):
198212        while  True :
199213            try :
200214                prioritized_task  =  self ._backlog_queue .get (block = False )
201-                 backlogged_tasks .append (prioritized_task . task )
215+                 backlogged_tasks .append (prioritized_task )
202216            except  queue .Empty :
203217                break 
204218
205219        for  backlogged_task  in  backlogged_tasks :
206-             self .put_task (backlogged_task )
220+             self .schedule_task (backlogged_task )
207221
208222    def  get_result (self , block : bool  =  True , timeout : Optional [float ] =  None ):
209223        """Return result and relinquish provisioned nodes""" 
0 commit comments