Skip to content

Commit f4c5dad

Browse files
committed
threading of download
1 parent d5180c3 commit f4c5dad

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

openeo/extra/job_management/__init__.py

Lines changed: 38 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
Optional,
2222
Union,
2323
)
24-
24+
import os
2525
import numpy
2626
import pandas as pd
2727
import requests
@@ -492,7 +492,7 @@ def run_jobs(
492492
# TODO: support user-provided `stats`
493493
stats = collections.defaultdict(int)
494494

495-
while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0:
495+
while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running", "downloading"]).values()) > 0:
496496
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
497497
stats["run_jobs loop"] += 1
498498

@@ -523,7 +523,7 @@ def _job_update_loop(
523523
not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy()
524524
if len(not_started) > 0:
525525
# Check number of jobs running at each backend
526-
running = job_db.get_by_status(statuses=["created", "queued", "running"])
526+
running = job_db.get_by_status(statuses=["created", "queued", "running"]) #TODO I believe we need to get downloading out?
527527
stats["job_db get_by_status"] += 1
528528
per_backend = running.groupby("backend_name").size().to_dict()
529529
_log.info(f"Running per backend: {per_backend}")
@@ -606,27 +606,44 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
606606
df.loc[i, "status"] = "skipped"
607607
stats["start_job skipped"] += 1
608608

609+
609610
def on_job_done(self, job: BatchJob, row):
610611
"""
611612
Handles jobs that have finished. Can be overridden to provide custom behaviour.
612613
613614
Default implementation downloads the results into a folder containing the title.
615+
Default implementation runs the download in a separate thread.
614616
615617
:param job: The job that has finished.
616618
:param row: DataFrame row containing the job's metadata.
617619
"""
618-
# TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
619-
620620
job_metadata = job.describe()
621621
job_dir = self.get_job_dir(job.job_id)
622622
metadata_path = self.get_job_metadata_path(job.job_id)
623-
624623
self.ensure_job_dir_exists(job.job_id)
625-
job.get_results().download_files(target=job_dir)
626624

625+
# Start download in a separate thread
626+
downloader = Thread(target=lambda: (
627+
self._job_download(job, job_dir, row) # Invoke the download logic directly
628+
))
629+
downloader.start()
630+
631+
# Write the job metadata to a file
627632
with metadata_path.open("w", encoding="utf-8") as f:
628633
json.dump(job_metadata, f, ensure_ascii=False)
629634

635+
def _job_download(self, job, job_dir, row):
636+
"""
637+
Download the job's results into the job directory. A failed download is logged; the job status is not updated here.
638+
"""
639+
try:
640+
# Start downloading the job's results
641+
job.get_results().download_files(target=job_dir)
642+
643+
except Exception as e:
644+
# If the download fails, log the error (the job status is not changed here)
645+
_log.error(f"Error downloading job {job.job_id}: {e}")
646+
630647
def on_job_error(self, job: BatchJob, row):
631648
"""
632649
Handles jobs that stopped with errors. Can be overridden to provide custom behaviour.
@@ -696,14 +713,15 @@ def ensure_job_dir_exists(self, job_id: str) -> Path:
696713
if not job_dir.exists():
697714
job_dir.mkdir(parents=True)
698715

716+
699717
def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] = None):
700718
"""
701719
Tracks status (and stats) of running jobs (in place).
702720
Optionally cancels jobs when running too long.
703721
"""
704722
stats = stats if stats is not None else collections.defaultdict(int)
705723

706-
active = job_db.get_by_status(statuses=["created", "queued", "running"]).copy()
724+
active = job_db.get_by_status(statuses=["created", "queued", "running", "downloading"]).copy()
707725
for i in active.index:
708726
job_id = active.loc[i, "id"]
709727
backend_name = active.loc[i, "backend_name"]
@@ -720,10 +738,20 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
720738
f"Status of job {job_id!r} (on backend {backend_name}) is {new_status!r} (previously {previous_status!r})"
721739
)
722740

723-
if new_status == "finished":
724-
stats["job finished"] += 1
741+
742+
#---------------------------------------
743+
744+
if new_status == "finished" and previous_status != "downloading":
745+
new_status = "downloading"
725746
self.on_job_done(the_job, active.loc[i])
726747

748+
if previous_status == "downloading":
749+
if self.get_job_metadata_path(job_id).exists():
750+
new_status = "finished"
751+
stats["job finished"] += 1
752+
else:
753+
new_status = "downloading"
754+
727755
if previous_status != "error" and new_status == "error":
728756
stats["job failed"] += 1
729757
self.on_job_error(the_job, active.loc[i])

0 commit comments

Comments
 (0)