|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import concurrent.futures |
6 | | -import threading |
7 | | -import queue |
8 | 6 | import logging |
9 | 7 | from abc import ABC, abstractmethod |
10 | 8 | from dataclasses import dataclass, field |
@@ -103,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No |
103 | 101 | connection.authenticate_bearer_token(self.bearer_token) |
104 | 102 | return connection |
105 | 103 |
|
106 | | - |
| 104 | +@dataclass(frozen=True) |
107 | 105 | class _JobStartTask(ConnectedTask): |
108 | 106 | """ |
109 | 107 | Task for starting an openEO batch job (the `POST /jobs/<job_id>/result` request). |
@@ -143,18 +141,16 @@ def execute(self) -> _TaskResult: |
143 | 141 | db_update={"status": "start_failed"}, |
144 | 142 | stats_update={"start_job error": 1}, |
145 | 143 | ) |
146 | | - |
| 144 | + |
| 145 | +@dataclass(frozen=True) |
147 | 146 | class _JobDownloadTask(ConnectedTask): |
148 | 147 | """ |
149 | 148 | Task for downloading job results and metadata. |
150 | 149 |
|
151 | 150 | :param download_dir: |
152 | 151 | Root directory where job results and metadata will be downloaded. |
153 | | - :param download_throttle: |
154 | | - A threading.Semaphore to limit concurrent downloads. |
155 | 152 | """ |
156 | | - download_dir: Path = field(repr=False) |
157 | | - |
| 153 | + download_dir: Path = field(default=None, repr=False) |
158 | 154 |
|
159 | 155 | def execute(self) -> _TaskResult: |
160 | 156 |
|
@@ -198,9 +194,10 @@ class _TaskThreadPool: |
198 | 194 | Defaults to 2. |
199 | 195 | """ |
200 | 196 |
|
201 | | - def __init__(self, max_workers: int = 2): |
| 197 | + def __init__(self, max_workers: int = 2, name: str = None): |
202 | 198 | self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) |
203 | 199 | self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = [] |
| 200 | + self._name = name |
204 | 201 |
|
205 | 202 | def submit_task(self, task: Task) -> None: |
206 | 203 | """ |
@@ -264,36 +261,77 @@ def shutdown(self) -> None: |
264 | 261 |
|
265 | 262 |
|
class _JobManagerWorkerThreadPool:
    """
    Generic wrapper that manages multiple thread pools with a dict.
    Uses task class names as pool names automatically: each task type gets its
    own ``_TaskThreadPool``, created on demand at first submission.
    """

    def __init__(self, pool_configs: Optional[Dict[str, int]] = None):
        """
        :param pool_configs: Dict of task_class_name -> max_workers
            Example: {"_JobStartTask": 1, "_JobDownloadTask": 2}
            Task types not listed get a single-worker pool by default.
        """
        self._pools: Dict[str, _TaskThreadPool] = {}
        self._pool_configs: Dict[str, int] = pool_configs or {}

    def _get_pool_name_for_task(self, task: Task) -> str:
        """
        Get pool name from task class name.
        """
        return type(task).__name__

    def submit_task(self, task: Task) -> None:
        """
        Submit a task to a pool named after its class.
        Creates pool dynamically if it doesn't exist.
        """
        pool_name = self._get_pool_name_for_task(task)

        pool = self._pools.get(pool_name)
        if pool is None:
            # Create pool on-demand; fall back to a single worker when unconfigured.
            max_workers = self._pool_configs.get(pool_name, 1)
            pool = _TaskThreadPool(max_workers=max_workers, name=pool_name)
            self._pools[pool_name] = pool
            # Lazy %-style args: formatting is skipped when INFO is disabled.
            _log.info("Created pool '%s' with %s workers", pool_name, max_workers)

        pool.submit_task(task)

    def process_all_updates(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]:
        """
        Process updates from ALL pools.

        :param timeout: per-pool timeout forwarded to ``process_futures``.
        :return: (all_results, dict of remaining tasks per pool)
        """
        all_results: List[_TaskResult] = []
        remaining_by_pool: Dict[str, int] = {}

        for pool_name, pool in self._pools.items():
            results, remaining = pool.process_futures(timeout)
            all_results.extend(results)
            remaining_by_pool[pool_name] = remaining

        return all_results, remaining_by_pool

    def num_pending_tasks(self, pool_name: Optional[str] = None) -> int:
        """
        Number of pending tasks in the given pool (0 if it does not exist),
        or summed across all pools when ``pool_name`` is None.
        """
        if pool_name:
            pool = self._pools.get(pool_name)
            return pool.num_pending_tasks() if pool else 0
        return sum(pool.num_pending_tasks() for pool in self._pools.values())

    def shutdown(self, pool_name: Optional[str] = None) -> None:
        """
        Shutdown pools.
        If pool_name is None, shuts down all pools.
        Shut-down pools are removed from the registry.
        """
        if pool_name:
            # pop() removes and returns in one step; missing name is a no-op.
            pool = self._pools.pop(pool_name, None)
            if pool is not None:
                pool.shutdown()
        else:
            # Shut down everything first, then drop all entries at once
            # (avoids mutating the dict while iterating it).
            for pool in self._pools.values():
                pool.shutdown()
            self._pools.clear()

    def list_pools(self) -> List[str]:
        """List all active pool names."""
        return list(self._pools.keys())
0 commit comments