Skip to content

Commit 1b1e6a2

Browse files
committed
A new PrioritizedTask dataclass is added that disables comparison on the `task: dict` element.
* Added new test to confirm that tasks with identical priorities do not raise TypeErrors.
1 parent 36016ea commit 1b1e6a2

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

parsl/executors/high_throughput/mpi_resource_management.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pickle
55
import queue
66
import subprocess
7+
from dataclasses import dataclass, field
78
from enum import Enum
89
from typing import Dict, List, Optional
910

@@ -69,6 +70,14 @@ def __str__(self):
6970
return f"MPINodesUnavailable(requested={self.requested} available={self.available})"
7071

7172

73+
@dataclass(eq=True, order=True)
class PrioritizedTask:
    """Queue entry that is compared by ``priority`` alone.

    The ``task`` dict is excluded from the generated comparison methods:
    dicts do not support ordering (``<``, ``>``), so comparing two entries
    with the same priority would otherwise raise ``TypeError``.
    """

    priority: int
    task: Dict = field(compare=False)
79+
80+
7281
class TaskScheduler:
7382
"""Default TaskScheduler that does no taskscheduling
7483
@@ -111,7 +120,7 @@ def __init__(
111120
super().__init__(pending_task_q, pending_result_q)
112121
self.scheduler = identify_scheduler()
113122
# PriorityQueue is threadsafe
114-
self._backlog_queue: queue.PriorityQueue = queue.PriorityQueue()
123+
self._backlog_queue: queue.PriorityQueue[PrioritizedTask] = queue.PriorityQueue()
115124
self._map_tasks_to_nodes: Dict[str, List[str]] = {}
116125
self.available_nodes = get_nodes_in_batchjob(self.scheduler)
117126
self._free_node_counter = SpawnContext.Value("i", len(self.available_nodes))
@@ -169,7 +178,7 @@ def put_task(self, task_package: dict):
169178
allocated_nodes = self._get_nodes(nodes_needed)
170179
except MPINodesUnavailable:
171180
logger.info(f"Not enough resources, placing task {tid} into backlog")
172-
self._backlog_queue.put((nodes_needed, task_package))
181+
self._backlog_queue.put(PrioritizedTask(nodes_needed, task_package))
173182
return
174183
else:
175184
resource_spec["MPI_NODELIST"] = ",".join(allocated_nodes)
@@ -183,8 +192,8 @@ def put_task(self, task_package: dict):
183192
def _schedule_backlog_tasks(self):
184193
"""Attempt to schedule backlogged tasks"""
185194
try:
186-
_nodes_requested, task_package = self._backlog_queue.get(block=False)
187-
self.put_task(task_package)
195+
prioritized_task = self._backlog_queue.get(block=False)
196+
self.put_task(prioritized_task.task)
188197
except queue.Empty:
189198
return
190199
else:

parsl/tests/test_mpi_apps/test_mpi_scheduler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,28 @@ def test_MPISched_contention():
161161
assert task_on_worker_side['task_id'] == 2
162162
_, _, _, resource_spec = unpack_res_spec_apply_message(task_on_worker_side['buffer'])
163163
assert len(resource_spec['MPI_NODELIST'].split(',')) == 8
164+
165+
166+
@pytest.mark.local
def test_hashable_backlog_queue():
    """Backlog several full-allocation tasks without a comparison error.

    Each of the three tasks requests 8 nodes while 8 nodes are available,
    so the test expects two of them to end up in the backlog queue. Those
    entries share the same priority, which requires the queued items to be
    comparable with one another.
    """
    pending_tasks = SpawnContext.Queue()
    pending_results = SpawnContext.Queue()
    scheduler = MPITaskScheduler(pending_tasks, pending_results)

    node_list = scheduler.available_nodes
    assert node_list
    assert len(node_list) == 8

    assert scheduler._free_node_counter.value == 8

    for task_id in range(3):
        spec = {"num_nodes": 8, "ranks_per_node": 2}
        packed = pack_res_spec_apply_message(
            "func", "args", "kwargs", resource_specification=spec
        )
        scheduler.put_task({"task_id": task_id, "buffer": packed})

    backlog_size = scheduler._backlog_queue.qsize()
    assert backlog_size == 2, "Expected 2 backlogged tasks"

0 commit comments

Comments
 (0)