Skip to content

Commit 382eae4

Browse files
committed
coupling task type to separate pool
1 parent 4fc299d commit 382eae4

File tree

4 files changed

+85
-47
lines changed

4 files changed

+85
-47
lines changed

openeo/extra/job_management/_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def run_loop():
367367
> 0
368368

369369
or (self._worker_pool.num_pending_tasks() > 0)
370-
370+
371371
and not self._stop_thread
372372
):
373373
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
@@ -644,7 +644,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
644644
df_idx=i,
645645
)
646646
_log.info(f"Submitting task {task} to thread pool")
647-
self._worker_pool.submit_start_task(task)
647+
self._worker_pool.submit_task(task)
648648

649649
stats["job_queued_for_start"] += 1
650650
df.loc[i, "status"] = "queued_for_start"
@@ -690,7 +690,7 @@ def _process_threadworker_updates(
690690
:param stats: Dictionary accumulating statistic counters
691691
"""
692692
# Retrieve completed task results immediately
693-
results, start_remaining, download_remaining = worker_pool.process_all_updates(timeout=0)
693+
results, _ = worker_pool.process_all_updates(timeout=0)
694694

695695
# Collect update dicts
696696
updates: List[Dict[str, Any]] = []
@@ -746,18 +746,18 @@ def on_job_done(self, job: BatchJob, row):
746746
self._refresh_bearer_token(connection=job_con)
747747

748748
task = _JobDownloadTask(
749-
root_url=job_con.root_url,
750-
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
751749
job_id=job.job_id,
752750
df_idx=row.name, # this is going to be the index in the not-yet-started dataframe; should not be an issue as there is no db update for the download task
751+
root_url=job_con.root_url,
752+
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
753753
download_dir=job_dir,
754754
)
755755
_log.info(f"Submitting download task {task} to download thread pool")
756756

757757
if self._worker_pool is None:
758758
self._worker_pool = _JobManagerWorkerThreadPool()
759759

760-
self._worker_pool.submit_download_task(task)
760+
self._worker_pool.submit_task(task)
761761

762762
def on_job_error(self, job: BatchJob, row):
763763
"""

openeo/extra/job_management/_thread_worker.py

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
"""
44

55
import concurrent.futures
6-
import threading
7-
import queue
86
import logging
97
from abc import ABC, abstractmethod
108
from dataclasses import dataclass, field
@@ -103,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No
103101
connection.authenticate_bearer_token(self.bearer_token)
104102
return connection
105103

106-
104+
@dataclass(frozen=True)
107105
class _JobStartTask(ConnectedTask):
108106
"""
109107
Task for starting an openEO batch job (the `POST /jobs/<job_id>/result` request).
@@ -143,18 +141,16 @@ def execute(self) -> _TaskResult:
143141
db_update={"status": "start_failed"},
144142
stats_update={"start_job error": 1},
145143
)
146-
144+
145+
@dataclass(frozen=True)
147146
class _JobDownloadTask(ConnectedTask):
148147
"""
149148
Task for downloading job results and metadata.
150149
151150
:param download_dir:
152151
Root directory where job results and metadata will be downloaded.
153-
:param download_throttle:
154-
A threading.Semaphore to limit concurrent downloads.
155152
"""
156-
download_dir: Path = field(repr=False)
157-
153+
download_dir: Path = field(default=None, repr=False)
158154

159155
def execute(self) -> _TaskResult:
160156

@@ -198,9 +194,10 @@ class _TaskThreadPool:
198194
Defaults to 2.
199195
"""
200196

201-
def __init__(self, max_workers: int = 2):
197+
def __init__(self, max_workers: int = 2, name: str = None):
202198
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
203199
self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = []
200+
self._name = name
204201

205202
def submit_task(self, task: Task) -> None:
206203
"""
@@ -264,36 +261,77 @@ def shutdown(self) -> None:
264261

265262

266263
class _JobManagerWorkerThreadPool:
267-
"""WRAPPER that hides two pools behind one interface"""
268-
269-
def __init__(self, max_start_workers=2, max_download_workers=10):
270-
# These are the TWO pools with their OWN _future_task_pairs
271-
self._start_pool = _TaskThreadPool(max_workers=max_start_workers)
272-
self._download_pool = _TaskThreadPool(max_workers=max_download_workers)
264+
265+
"""
266+
Generic wrapper that manages multiple thread pools with a dict.
267+
Uses task class names as pool names automatically.
268+
"""
273269

274-
def submit_start_task(self, task):
275-
# Delegate to start pool
276-
self._start_pool.submit_task(task)
270+
def __init__(self, pool_configs: Optional[Dict[str, int]] = None):
271+
"""
272+
:param pool_configs: Dict of task_class_name -> max_workers
273+
Example: {"_JobStartTask": 1, "_JobDownloadTask": 2}
274+
"""
275+
self._pools: Dict[str, _TaskThreadPool] = {}
276+
self._pool_configs = pool_configs or {}
277277

278-
def submit_download_task(self, task):
279-
# Delegate to download pool
280-
self._download_pool.submit_task(task)
278+
def _get_pool_name_for_task(self, task: Task) -> str:
279+
"""
280+
Get pool name from task class name.
281+
"""
282+
return task.__class__.__name__
281283

282-
def process_all_updates(self, timeout=0):
283-
# Get results from BOTH pools
284-
start_results, start_remaining = self._start_pool.process_futures(timeout)
285-
download_results, download_remaining = self._download_pool.process_futures(timeout)
284+
def submit_task(self, task: Task) -> None:
285+
"""
286+
Submit a task to a pool named after its class.
287+
Creates pool dynamically if it doesn't exist.
288+
"""
289+
pool_name = self._get_pool_name_for_task(task)
286290

287-
# Combine and return
288-
all_results = start_results + download_results
289-
return all_results, start_remaining, download_remaining
291+
if pool_name not in self._pools:
292+
# Create pool on-demand
293+
max_workers = self._pool_configs.get(pool_name, 1) # Default 1 worker
294+
self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers, name=pool_name)
295+
_log.info(f"Created pool '{pool_name}' with {max_workers} workers")
296+
297+
self._pools[pool_name].submit_task(task)
298+
299+
def process_all_updates(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]:
300+
"""
301+
Process updates from ALL pools.
302+
Returns: (all_results, dict of remaining tasks per pool)
303+
"""
304+
all_results = []
305+
remaining_by_pool = {}
306+
307+
for pool_name, pool in self._pools.items():
308+
results, remaining = pool.process_futures(timeout)
309+
all_results.extend(results)
310+
remaining_by_pool[pool_name] = remaining
311+
312+
return all_results, remaining_by_pool
290313

291-
def num_pending_tasks(self):
292-
# Sum of BOTH pools
293-
return (self._start_pool.num_pending_tasks() +
294-
self._download_pool.num_pending_tasks())
314+
def num_pending_tasks(self, pool_name: Optional[str] = None) -> int:
315+
if pool_name:
316+
pool = self._pools.get(pool_name)
317+
return pool.num_pending_tasks() if pool else 0
318+
else:
319+
return sum(pool.num_pending_tasks() for pool in self._pools.values())
320+
321+
def shutdown(self, pool_name: Optional[str] = None) -> None:
322+
"""
323+
Shutdown pools.
324+
If pool_name is None, shuts down all pools.
325+
"""
326+
if pool_name:
327+
if pool_name in self._pools:
328+
self._pools[pool_name].shutdown()
329+
del self._pools[pool_name]
330+
else:
331+
for pool_name, pool in list(self._pools.items()):
332+
pool.shutdown()
333+
del self._pools[pool_name]
295334

296-
def shutdown(self):
297-
# Shutdown BOTH pools
298-
self._start_pool.shutdown()
299-
self._download_pool.shutdown()
335+
def list_pools(self) -> List[str]:
336+
"""List all active pool names."""
337+
return list(self._pools.keys())

tests/extra/job_management/test_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def get_status(job_id, current_status):
729729
assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime)
730730

731731
def test_process_threadworker_updates(self, tmp_path, caplog):
732-
pool = _JobManagerWorkerThreadPool(max_workers=2)
732+
pool = _JobManagerWorkerThreadPool()
733733
stats = collections.defaultdict(int)
734734

735735
# Submit tasks covering all cases
@@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
769769
assert caplog.messages == []
770770

771771
def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
772-
pool = _JobManagerWorkerThreadPool(max_workers=2)
772+
pool = _JobManagerWorkerThreadPool()
773773
stats = collections.defaultdict(int)
774774

775775
pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1}))
@@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
806806
assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")]
807807

808808
def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
809-
pool = _JobManagerWorkerThreadPool(max_workers=2)
809+
pool = _JobManagerWorkerThreadPool()
810810
stats = collections.defaultdict(int)
811811

812812
df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]})
@@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
820820
assert stats == {}
821821

822822
def test_logs_on_invalid_update(self, tmp_path, caplog):
823-
pool = _JobManagerWorkerThreadPool(max_workers=2)
823+
pool = _JobManagerWorkerThreadPool()
824824
stats = collections.defaultdict(int)
825825

826826
# Malformed db_update (not a dict unpackable via **)

tests/extra/job_management/test_thread_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class TestJobManagerWorkerThreadPool:
221221
@pytest.fixture
222222
def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]:
223223
"""Fixture for creating and cleaning up a worker thread pool."""
224-
pool = _JobManagerWorkerThreadPool(max_workers=2)
224+
pool = _JobManagerWorkerThreadPool()
225225
yield pool
226226
pool.shutdown()
227227

0 commit comments

Comments
 (0)