Skip to content

Commit d7a87da

Browse files
committed
adding download task and creating separate download pool
1 parent d90f06b commit d7a87da

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

openeo/extra/job_management/_manager.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from openeo.extra.job_management._thread_worker import (
3333
_JobManagerWorkerThreadPool,
3434
_JobStartTask,
35+
_JobDownloadTask
3536
)
3637
from openeo.rest import OpenEoApiError
3738
from openeo.rest.auth.auth import BearerAuth
@@ -172,6 +173,7 @@ def start_job(
172173
173174
.. versionchanged:: 0.47.0
174175
Added ``download_results`` parameter.
176+
175177
"""
176178

177179
# Expected columns in the job DB dataframes.
@@ -219,6 +221,7 @@ def __init__(
219221
)
220222
self._thread = None
221223
self._worker_pool = None
224+
self._download_pool = None
222225
# Generic cache
223226
self._cache = {}
224227

@@ -351,6 +354,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas
351354

352355
self._stop_thread = False
353356
self._worker_pool = _JobManagerWorkerThreadPool()
357+
self._download_pool = _JobManagerWorkerThreadPool()
354358

355359
def run_loop():
356360
# TODO: support user-provided `stats`
@@ -388,7 +392,13 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):
388392
389393
.. versionadded:: 0.32.0
390394
"""
391-
self._worker_pool.shutdown()
395+
if self._worker_pool is not None:
396+
self._worker_pool.shutdown()
397+
self._worker_pool = None
398+
399+
if self._download_pool is not None:
400+
self._download_pool.shutdown()
401+
self._download_pool = None
392402

393403
if self._thread is not None:
394404
self._stop_thread = True
@@ -493,6 +503,8 @@ def run_jobs(
493503
stats = collections.defaultdict(int)
494504

495505
self._worker_pool = _JobManagerWorkerThreadPool()
506+
self._download_pool = _JobManagerWorkerThreadPool()
507+
496508

497509
while (
498510
sum(
@@ -511,7 +523,7 @@ def run_jobs(
511523
stats["sleep"] += 1
512524

513525
# TODO; run post process after shutdown once more to ensure completion?
514-
self._worker_pool.shutdown()
526+
self.stop_job_thread()
515527

516528
return stats
517529

@@ -553,7 +565,11 @@ def _job_update_loop(
553565
stats["job_db persist"] += 1
554566
total_added += 1
555567

556-
self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
568+
if self._worker_pool is not None:
569+
self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats)
570+
571+
if self._download_pool is not None:
572+
self._process_threadworker_updates(worker_pool=self._download_pool, job_db=job_db, stats=stats)
557573

558574
# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
559575
for job, row in jobs_done:
@@ -565,6 +581,7 @@ def _job_update_loop(
565581
for job, row in jobs_cancel:
566582
self.on_job_cancel(job, row)
567583

584+
568585
def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
569586
"""Helper method for launching jobs
570587
@@ -657,7 +674,7 @@ def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60)
657674
else:
658675
_log.warning("Failed to proactively refresh bearer token")
659676

660-
def _process_threadworker_updates(
677+
def _process_task_results(
661678
self,
662679
worker_pool: _JobManagerWorkerThreadPool,
663680
*,
@@ -723,15 +740,23 @@ def on_job_done(self, job: BatchJob, row):
723740
"""
724741
# TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
725742
if self._download_results:
726-
job_metadata = job.describe()
727-
job_dir = self.get_job_dir(job.job_id)
728-
metadata_path = self.get_job_metadata_path(job.job_id)
729743

744+
job_dir = self.get_job_dir(job.job_id)
730745
self.ensure_job_dir_exists(job.job_id)
731-
job.get_results().download_files(target=job_dir)
732746

733-
with metadata_path.open("w", encoding="utf-8") as f:
734-
json.dump(job_metadata, f, ensure_ascii=False)
747+
# Proactively refresh bearer token (because the task in the worker thread will not be able to do that)
748+
job_con = job.connection
749+
self._refresh_bearer_token(connection=job_con)
750+
751+
task = _JobDownloadTask(
752+
root_url=job_con.root_url,
753+
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
754+
job_id=job.job_id,
755+
df_idx=row.name, # TODO figure out correct index usage
756+
download_dir=job_dir,
757+
)
758+
_log.info(f"Submitting download task {task} to download thread pool")
759+
self._download_pool.submit_task(task)
735760

736761
def on_job_error(self, job: BatchJob, row):
737762
"""

openeo/extra/job_management/_thread_worker.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from abc import ABC, abstractmethod
88
from dataclasses import dataclass, field
99
from typing import Any, Dict, List, Optional, Tuple, Union
10+
from pathlib import Path
1011

12+
import json
1113
import urllib3.util
1214

1315
import openeo
@@ -140,6 +142,50 @@ def execute(self) -> _TaskResult:
140142
stats_update={"start_job error": 1},
141143
)
142144

145+
@dataclass(frozen=True)
class _JobDownloadTask(ConnectedTask):
    """
    Worker-thread task that fetches a finished job's result assets
    and its metadata document to local disk.

    :param download_dir:
        Root directory where job results and metadata will be downloaded.
    """

    # Target directory for result files and the metadata JSON;
    # created (including parents) when the task runs.
    download_dir: Path

    def execute(self) -> _TaskResult:
        """
        Download the job's result files and metadata.

        :return: a :py:class:`_TaskResult` with a success or error
            stats update; failures are logged, never raised.
        """
        try:
            # Reconnect inside the worker thread (with retries) and
            # resolve the job handle by id.
            connection = self.get_connection(retry=True)
            job = connection.job(self.job_id)

            # Make sure the target directory exists before writing anything.
            self.download_dir.mkdir(parents=True, exist_ok=True)

            # Fetch all result assets into the download directory.
            job.get_results().download_files(target=self.download_dir)

            # Persist the job metadata document alongside the assets.
            metadata = job.describe()
            metadata_file = self.download_dir / f"job_{self.job_id}.json"
            with metadata_file.open("w", encoding="utf-8") as fh:
                json.dump(metadata, fh, ensure_ascii=False)

            _log.info(f"Job {self.job_id!r} results downloaded successfully")
            return _TaskResult(
                job_id=self.job_id,
                df_idx=self.df_idx,
                db_update={},  # TODO consider db updates
                stats_update={"job download": 1},
            )
        except Exception as e:
            # Report the failure through the result object so the
            # manager can account for it; don't let it escape the pool.
            _log.error(f"Failed to download results for job {self.job_id!r}: {e!r}")
            return _TaskResult(
                job_id=self.job_id,
                df_idx=self.df_idx,
                db_update={},
                stats_update={"job download error": 1},
            )
143189

144190
class _JobManagerWorkerThreadPool:
145191
"""

0 commit comments

Comments
 (0)