66import subprocess
77from dataclasses import dataclass , field
88from enum import Enum
9- from typing import Dict , List , Optional
9+ from typing import Dict , List , Optional , Tuple
1010
1111from parsl .multiprocessing import SpawnContext
1212from parsl .serialize import pack_res_spec_apply_message , unpack_res_spec_apply_message
@@ -76,6 +76,8 @@ class PrioritizedTask:
7676 # This dataclass limits comparison to the priority field
7777 priority : int
7878 task : Dict = field (compare = False )
79+ unpacked_task : Tuple = field (compare = False )
80+ nodes_needed : int = field (compare = False )
7981
8082
8183class TaskScheduler :
@@ -165,29 +167,41 @@ def _return_nodes(self, nodes: List[str]) -> None:
165167 with self ._free_node_counter .get_lock ():
166168 self ._free_node_counter .value += len (nodes ) # type: ignore[attr-defined]
167169
168- def put_task (self , task_package : dict ):
169- """Schedule task if resources are available otherwise backlog the task"""
170- user_ns = locals ()
171- user_ns .update ({"__builtins__" : __builtins__ })
172- _f , _args , _kwargs , resource_spec = unpack_res_spec_apply_message (task_package ["buffer" ])
173-
174- nodes_needed = resource_spec .get ("num_nodes" )
175- tid = task_package ["task_id" ]
170+ def schedule_task (self , p_task : PrioritizedTask ):
171+ """Schedule a prioritized task if resources are available, and push to backlog
172+ otherwise."""
173+ nodes_needed = p_task .nodes_needed
174+ tid = p_task .task ["task_id" ]
176175 if nodes_needed :
177176 try :
178177 allocated_nodes = self ._get_nodes (nodes_needed )
179178 except MPINodesUnavailable :
180179 logger .info (f"Not enough resources, placing task { tid } into backlog" )
181- self ._backlog_queue .put (PrioritizedTask ( nodes_needed , task_package ) )
180+ self ._backlog_queue .put (p_task )
182181 return
183182 else :
183+ f , args , kwargs , resource_spec = p_task .unpacked_task
184184 resource_spec ["MPI_NODELIST" ] = "," .join (allocated_nodes )
185185 self ._map_tasks_to_nodes [tid ] = allocated_nodes
186- buffer = pack_res_spec_apply_message (_f , _args , _kwargs , resource_spec )
187- task_package ["buffer" ] = buffer
188- task_package ["resource_spec" ] = resource_spec
186+ buffer = pack_res_spec_apply_message (f , args , kwargs , resource_spec )
187+ p_task .task ["buffer" ] = buffer
188+ p_task .task ["resource_spec" ] = resource_spec
189+
190+ self .pending_task_q .put (p_task .task )
191+
192+ def put_task (self , task_package : dict ):
193+ """Schedule task if resources are available otherwise backlog the task"""
194+ user_ns = locals ()
195+ user_ns .update ({"__builtins__" : __builtins__ })
196+ _f , _args , _kwargs , resource_spec = unpack_res_spec_apply_message (task_package ["buffer" ])
197+
198+ nodes_needed = resource_spec .get ("num_nodes" )
199+ prioritized_task = PrioritizedTask (priority = nodes_needed ,
200+ task = task_package ,
201+ unpacked_task = (_f , _args , _kwargs , resource_spec ),
202+ nodes_needed = nodes_needed )
189203
190- self .pending_task_q . put ( task_package )
204+ self .schedule_task ( prioritized_task )
191205
192206 def _schedule_backlog_tasks (self ):
193207 """Attempt to schedule backlogged tasks"""
@@ -198,12 +212,12 @@ def _schedule_backlog_tasks(self):
198212 while True :
199213 try :
200214 prioritized_task = self ._backlog_queue .get (block = False )
201- backlogged_tasks .append (prioritized_task . task )
215+ backlogged_tasks .append (prioritized_task )
202216 except queue .Empty :
203217 break
204218
205219 for backlogged_task in backlogged_tasks :
206- self .put_task (backlogged_task )
220+ self .schedule_task (backlogged_task )
207221
208222 def get_result (self , block : bool = True , timeout : Optional [float ] = None ):
209223 """Return result and relinquish provisioned nodes"""
0 commit comments