2121 Optional ,
2222 Union ,
2323)
24-
24+ import os
2525import numpy
2626import pandas as pd
2727import requests
@@ -492,7 +492,7 @@ def run_jobs(
492492 # TODO: support user-provided `stats`
493493 stats = collections .defaultdict (int )
494494
495- while sum (job_db .count_by_status (statuses = ["not_started" , "created" , "queued" , "running" ]).values ()) > 0 :
495+ while sum (job_db .count_by_status (statuses = ["not_started" , "created" , "queued" , "running" , "downloading" ]).values ()) > 0 :
496496 self ._job_update_loop (job_db = job_db , start_job = start_job , stats = stats )
497497 stats ["run_jobs loop" ] += 1
498498
@@ -523,7 +523,7 @@ def _job_update_loop(
523523 not_started = job_db .get_by_status (statuses = ["not_started" ], max = 200 ).copy ()
524524 if len (not_started ) > 0 :
525525 # Check number of jobs running at each backend
526- running = job_db .get_by_status (statuses = ["created" , "queued" , "running" ])
526+ running = job_db .get_by_status (statuses = ["created" , "queued" , "running" ]) #TODO I believe we need to get downloading out?
527527 stats ["job_db get_by_status" ] += 1
528528 per_backend = running .groupby ("backend_name" ).size ().to_dict ()
529529 _log .info (f"Running per backend: { per_backend } " )
@@ -606,27 +606,44 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
606606 df .loc [i , "status" ] = "skipped"
607607 stats ["start_job skipped" ] += 1
608608
609+
609610 def on_job_done (self , job : BatchJob , row ):
610611 """
611612 Handles jobs that have finished. Can be overridden to provide custom behaviour.
612613
613614 Default implementation downloads the results into a folder containing the title.
615+ Default implementation runs the download in a separate thread.
614616
615617 :param job: The job that has finished.
616618 :param row: DataFrame row containing the job's metadata.
617619 """
618- # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
619-
620620 job_metadata = job .describe ()
621621 job_dir = self .get_job_dir (job .job_id )
622622 metadata_path = self .get_job_metadata_path (job .job_id )
623-
624623 self .ensure_job_dir_exists (job .job_id )
625- job .get_results ().download_files (target = job_dir )
626624
625+ # Start download in a separate thread
626+ downloader = Thread (target = lambda : (
627+ self ._job_download (job , job_dir , row ) # Invoke the download logic directly
628+ ))
629+ downloader .start ()
630+
631+ # Write the job metadata to a file
627632 with metadata_path .open ("w" , encoding = "utf-8" ) as f :
628633 json .dump (job_metadata , f , ensure_ascii = False )
629634
635+ def _job_download (self , job , job_dir , row ):
636+ """
637+ Download the job's results and update the job status after the download completes.
638+ """
639+ try :
640+ # Start downloading the job's results
641+ job .get_results ().download_files (target = job_dir )
642+
643+ except Exception as e :
644+ # If the download fails, set the status to 'error'
645+ _log .error (f"Error downloading job { job .job_id } : { e } " )
646+
630647 def on_job_error (self , job : BatchJob , row ):
631648 """
632649 Handles jobs that stopped with errors. Can be overridden to provide custom behaviour.
@@ -696,14 +713,15 @@ def ensure_job_dir_exists(self, job_id: str) -> Path:
696713 if not job_dir .exists ():
697714 job_dir .mkdir (parents = True )
698715
716+
699717 def _track_statuses (self , job_db : JobDatabaseInterface , stats : Optional [dict ] = None ):
700718 """
701719 Tracks status (and stats) of running jobs (in place).
702720 Optionally cancels jobs when running too long.
703721 """
704722 stats = stats if stats is not None else collections .defaultdict (int )
705723
706- active = job_db .get_by_status (statuses = ["created" , "queued" , "running" ]).copy ()
724+ active = job_db .get_by_status (statuses = ["created" , "queued" , "running" , "downloading" ]).copy ()
707725 for i in active .index :
708726 job_id = active .loc [i , "id" ]
709727 backend_name = active .loc [i , "backend_name" ]
@@ -720,10 +738,20 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
720738 f"Status of job { job_id !r} (on backend { backend_name } ) is { new_status !r} (previously { previous_status !r} )"
721739 )
722740
723- if new_status == "finished" :
724- stats ["job finished" ] += 1
741+
742+ #---------------------------------------
743+
744+ if new_status == "finished" and previous_status != "downloading" :
745+ new_status = "downloading"
725746 self .on_job_done (the_job , active .loc [i ])
726747
748+ if previous_status == "downloading" :
749+ if self .get_job_metadata_path (job_id ).exists ():
750+ new_status = "finished"
751+ stats ["job finished" ] += 1
752+ else :
753+ new_status = "downloading"
754+
727755 if previous_status != "error" and new_status == "error" :
728756 stats ["job failed" ] += 1
729757 self .on_job_error (the_job , active .loc [i ])
0 commit comments