3232from openeo .extra .job_management ._thread_worker import (
3333 _JobManagerWorkerThreadPool ,
3434 _JobStartTask ,
35+ _JobDownloadTask
3536)
3637from openeo .rest import OpenEoApiError
3738from openeo .rest .auth .auth import BearerAuth
@@ -172,6 +173,7 @@ def start_job(
172173
173174 .. versionchanged:: 0.47.0
174175 Added ``download_results`` parameter.
176+
175177 """
176178
177179 # Expected columns in the job DB dataframes.
@@ -219,6 +221,7 @@ def __init__(
219221 )
220222 self ._thread = None
221223 self ._worker_pool = None
224+ self ._download_pool = None
222225 # Generic cache
223226 self ._cache = {}
224227
@@ -351,6 +354,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas
351354
352355 self ._stop_thread = False
353356 self ._worker_pool = _JobManagerWorkerThreadPool ()
357+ self ._download_pool = _JobManagerWorkerThreadPool ()
354358
355359 def run_loop ():
356360 # TODO: support user-provided `stats`
@@ -388,7 +392,13 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):
388392
389393 .. versionadded:: 0.32.0
390394 """
391- self ._worker_pool .shutdown ()
395+ if self ._worker_pool is not None :
396+ self ._worker_pool .shutdown ()
397+ self ._worker_pool = None
398+
399+ if self ._download_pool is not None :
400+ self ._download_pool .shutdown ()
401+ self ._download_pool = None
392402
393403 if self ._thread is not None :
394404 self ._stop_thread = True
@@ -493,6 +503,8 @@ def run_jobs(
493503 stats = collections .defaultdict (int )
494504
495505 self ._worker_pool = _JobManagerWorkerThreadPool ()
506+ self ._download_pool = _JobManagerWorkerThreadPool ()
507+
496508
497509 while (
498510 sum (
@@ -511,7 +523,7 @@ def run_jobs(
511523 stats ["sleep" ] += 1
512524
513525 # TODO; run post process after shutdown once more to ensure completion?
514- self ._worker_pool . shutdown ()
526+ self .stop_job_thread ()
515527
516528 return stats
517529
@@ -553,7 +565,11 @@ def _job_update_loop(
553565 stats ["job_db persist" ] += 1
554566 total_added += 1
555567
556- self ._process_threadworker_updates (self ._worker_pool , job_db = job_db , stats = stats )
568+ if self ._worker_pool is not None :
569+ self ._process_threadworker_updates (worker_pool = self ._worker_pool , job_db = job_db , stats = stats )
570+
571+ if self ._download_pool is not None :
572+ self ._process_threadworker_updates (worker_pool = self ._download_pool , job_db = job_db , stats = stats )
557573
558574 # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
559575 for job , row in jobs_done :
@@ -565,6 +581,7 @@ def _job_update_loop(
565581 for job , row in jobs_cancel :
566582 self .on_job_cancel (job , row )
567583
584+
568585 def _launch_job (self , start_job , df , i , backend_name , stats : Optional [dict ] = None ):
569586 """Helper method for launching jobs
570587
@@ -657,7 +674,7 @@ def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60)
657674 else :
658675 _log .warning ("Failed to proactively refresh bearer token" )
659676
660- def _process_threadworker_updates (
677+ def _process_task_results (
661678 self ,
662679 worker_pool : _JobManagerWorkerThreadPool ,
663680 * ,
@@ -723,15 +740,23 @@ def on_job_done(self, job: BatchJob, row):
723740 """
724741 # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
725742 if self ._download_results :
726- job_metadata = job .describe ()
727- job_dir = self .get_job_dir (job .job_id )
728- metadata_path = self .get_job_metadata_path (job .job_id )
729743
744+ job_dir = self .get_job_dir (job .job_id )
730745 self .ensure_job_dir_exists (job .job_id )
731- job .get_results ().download_files (target = job_dir )
732746
733- with metadata_path .open ("w" , encoding = "utf-8" ) as f :
734- json .dump (job_metadata , f , ensure_ascii = False )
747+ # Proactively refresh bearer token (because task in thread will not be able to do that)
748+ job_con = job .connection
749+ self ._refresh_bearer_token (connection = job_con )
750+
751+ task = _JobDownloadTask (
752+ root_url = job_con .root_url ,
753+ bearer_token = job_con .auth .bearer if isinstance (job_con .auth , BearerAuth ) else None ,
754+ job_id = job .job_id ,
755+ df_idx = row .name , # TODO figure out correct index usage
756+ download_dir = job_dir ,
757+ )
758+ _log .info (f"Submitting download task { task } to download thread pool" )
759+ self ._download_pool .submit_task (task )
735760
736761 def on_job_error (self , job : BatchJob , row ):
737762 """
0 commit comments