 from urllib3.util import Retry
 
 from openeo import BatchJob, Connection
+from openeo.extra.job_management._thread_worker import (
+    _JobManagerWorkerThreadPool,
+    _JobStartTask,
+)
 from openeo.internal.processes.parse import (
     Parameter,
     Process,
     parse_remote_process_definition,
 )
 from openeo.rest import OpenEoApiError
+from openeo.rest.auth.auth import BearerAuth
 from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339
 
 _log = logging.getLogger(__name__)
@@ -76,8 +81,10 @@ def exists(self) -> bool:
     @abc.abstractmethod
     def persist(self, df: pd.DataFrame):
         """
-        Store job data to the database.
-        The provided dataframe may contain partial information, which is merged into the larger database.
+        Store (new or updated) job data to the database.
+
+        The provided dataframe may only cover a subset of all the jobs ("rows") of the whole database,
+        so it should be merged with the existing data (if any) instead of overwriting it completely.
 
         :param df: job data to store.
         """
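
# Illustrative sketch of how a pandas-backed `persist()` could honour the merge
# semantics described in the docstring above. The class `_SketchJobDatabase` and
# its `_df` attribute are assumptions for illustration, not part of this patch.
import pandas as pd


class _SketchJobDatabase:
    def __init__(self):
        self._df = None

    def persist(self, df: pd.DataFrame):
        if self._df is None:
            self._df = df.copy()
        else:
            # Upsert: overwrite overlapping rows/columns, then append unseen rows.
            self._df.update(df, overwrite=True)
            new_rows = df.loc[df.index.difference(self._df.index)]
            self._df = pd.concat([self._df, new_rows])
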
@@ -106,6 +113,18 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame:
         """
         ...
 
+    @abc.abstractmethod
+    def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
+        """
+        Returns a dataframe with jobs based on their (dataframe) index.
+
+        :param indices: List of indices to include.
+
+        :return: DataFrame with jobs filtered by indices.
+        """
+        ...
+
+
 def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs):
     raise NotImplementedError("No 'start_job' callable provided")
 
@@ -187,6 +206,7 @@ def start_job(
 
 # Expected columns in the job DB dataframes.
 # TODO: make this part of public API when settled?
+# TODO: move non-official statuses to a separate column (not_started, queued_for_start)
 _COLUMN_REQUIREMENTS: Mapping[str, _ColumnProperties] = {
     "id": _ColumnProperties(dtype="str"),
     "backend_name": _ColumnProperties(dtype="str"),
@@ -223,6 +243,7 @@ def __init__(
             datetime.timedelta(seconds=cancel_running_job_after) if cancel_running_job_after is not None else None
         )
         self._thread = None
+        self._worker_pool = None
 
     def add_backend(
         self,
@@ -359,21 +380,27 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas
         _log.info(f"Resuming `run_jobs` from existing {job_db}")
 
         self._stop_thread = False
+        self._worker_pool = _JobManagerWorkerThreadPool()
 
         def run_loop():
 
             # TODO: support user-provided `stats`
             stats = collections.defaultdict(int)
 
             while (
-                sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0
+                sum(
+                    job_db.count_by_status(
+                        statuses=["not_started", "created", "queued", "queued_for_start", "running"]
+                    ).values()
+                )
+                > 0
                 and not self._stop_thread
             ):
-                self._job_update_loop(job_db=job_db, start_job=start_job)
+                self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
                 stats["run_jobs loop"] += 1
 
+                # Show current stats and sleep
                 _log.info(f"Job status histogram: {job_db.count_by_status()}. Run stats: {dict(stats)}")
-                # Do sequence of micro-sleeps to allow for quick thread exit
                 for _ in range(int(max(1, self.poll_sleep))):
                     time.sleep(1)
                     if self._stop_thread:
@@ -392,6 +419,8 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):
 
         .. versionadded:: 0.32.0
         """
+        self._worker_pool.shutdown()
+
         if self._thread is not None:
             self._stop_thread = True
             if timeout_seconds is _UNSET:
@@ -494,7 +523,16 @@ def run_jobs(
         # TODO: support user-provided `stats`
         stats = collections.defaultdict(int)
 
-        while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0:
+        self._worker_pool = _JobManagerWorkerThreadPool()
+
+        while (
+            sum(
+                job_db.count_by_status(
+                    statuses=["not_started", "created", "queued_for_start", "queued", "running"]
+                ).values()
+            )
+            > 0
+        ):
             self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
             stats["run_jobs loop"] += 1
 
@@ -503,6 +541,9 @@ def run_jobs(
             time.sleep(self.poll_sleep)
             stats["sleep"] += 1
 
+        # TODO: run post-processing once more after shutdown, to ensure completion?
+        self._worker_pool.shutdown()
+
         return stats
 
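
# Hedged usage sketch for the `run_jobs` entry point above. The backend URL,
# dataframe contents and `start_job` body below are placeholders (assumptions),
# not taken from this patch; authentication setup is omitted.
import pandas as pd

import openeo
from openeo.extra.job_management import CsvJobDatabase, MultiBackendJobManager


def start_job(row: pd.Series, connection: openeo.Connection, **kwargs):
    # Build a small (placeholder) batch job for one dataframe row.
    cube = connection.load_collection(row["collection_id"])
    return cube.create_job(title=f"job for {row['collection_id']}")


manager = MultiBackendJobManager()
manager.add_backend("cdse", connection=openeo.connect("openeo.dataspace.copernicus.eu"))

job_db = CsvJobDatabase("jobs.csv").initialize_from_df(pd.DataFrame({"collection_id": ["SENTINEL2_L2A"]}))
stats = manager.run_jobs(start_job=start_job, job_db=job_db)
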
     def _job_update_loop(
@@ -525,7 +566,8 @@ def _job_update_loop(
         not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy()
         if len(not_started) > 0:
             # Check number of jobs running at each backend
-            running = job_db.get_by_status(statuses=["created", "queued", "running"])
+            # TODO: should "created" be included in here? Calling this "running" is quite misleading then.
+            running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"])
             stats["job_db get_by_status"] += 1
             per_backend = running.groupby("backend_name").size().to_dict()
             _log.info(f"Running per backend: {per_backend}")
@@ -542,7 +584,9 @@ def _job_update_loop(
                         stats["job_db persist"] += 1
                         total_added += 1
 
-        # Act on jobs
+        self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
+
+        # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling also happens in threads?
         for job, row in jobs_done:
             self.on_job_done(job, row)
 
@@ -552,7 +596,6 @@ def _job_update_loop(
         for job, row in jobs_cancel:
             self.on_job_cancel(job, row)
 
-
     def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
         """Helper method for launching jobs
 
@@ -599,26 +642,91 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
             df.loc[i, "start_time"] = rfc3339.now_utc()
             if job:
                 df.loc[i, "id"] = job.job_id
+                _log.info(f"Job created: {job.job_id}")
                 with ignore_connection_errors(context="get status"):
                     status = job.status()
                     stats["job get status"] += 1
                     df.loc[i, "status"] = status
                     if status == "created":
                         # start job if not yet done by callback
                         try:
-                            job.start()
-                            stats["job start"] += 1
-                            df.loc[i, "status"] = job.status()
-                            stats["job get status"] += 1
+                            job_con = job.connection
+                            task = _JobStartTask(
+                                root_url=job_con.root_url,
+                                bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
+                                job_id=job.job_id,
+                                df_idx=i,
+                            )
+                            _log.info(f"Submitting task {task} to thread pool")
+                            self._worker_pool.submit_task(task)
+
+                            stats["job_queued_for_start"] += 1
+                            df.loc[i, "status"] = "queued_for_start"
                         except OpenEoApiError as e:
-                            _log.error(e)
-                            df.loc[i, "status"] = "start_failed"
-                            stats["job start error"] += 1
+                            _log.info(f"Failed submitting task {task} to thread pool with error: {e}")
+                            df.loc[i, "status"] = "queued_for_start_failed"
+                            stats["job queued for start failed"] += 1
             else:
                 # TODO: what is this "skipping" about actually?
                 df.loc[i, "status"] = "skipped"
                 stats["start_job skipped"] += 1
 
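
# The `_thread_worker` helpers used above are not shown in this diff. The stand-in
# below is an illustrative assumption (class and attribute names invented here) of
# the contract that `_launch_job` and `_process_threadworker_updates` rely on:
# tasks are submitted to a thread pool and come back as results carrying
# `db_update` / `stats_update` payloads keyed by `df_idx`.
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union


@dataclass
class _SketchTaskResult:
    job_id: str
    df_idx: Union[int, str]
    db_update: Dict[str, str] = field(default_factory=dict)
    stats_update: Dict[str, int] = field(default_factory=dict)


@dataclass
class _SketchJobStartTask:
    root_url: str
    bearer_token: Optional[str]
    job_id: str
    df_idx: Union[int, str]

    def execute(self) -> _SketchTaskResult:
        # A real task would rebuild a connection from root_url/bearer_token and
        # start the batch job; this sketch just reports a successful start.
        return _SketchTaskResult(
            job_id=self.job_id,
            df_idx=self.df_idx,
            db_update={"status": "queued"},
            stats_update={"job start": 1},
        )


class _SketchWorkerThreadPool:
    def __init__(self, max_workers: int = 2):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._futures: List[Future] = []

    def submit_task(self, task: _SketchJobStartTask) -> None:
        self._futures.append(self._executor.submit(task.execute))

    def process_futures(self, timeout: float = 0) -> Tuple[List[_SketchTaskResult], List[Future]]:
        # Return results of finished tasks and keep pending ones (timeout ignored in this sketch).
        done = [f for f in self._futures if f.done()]
        self._futures = [f for f in self._futures if not f.done()]
        return [f.result() for f in done], list(self._futures)

    def shutdown(self) -> None:
        self._executor.shutdown(wait=True)
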
+    def _process_threadworker_updates(
+        self,
+        worker_pool: _JobManagerWorkerThreadPool,
+        *,
+        job_db: JobDatabaseInterface,
+        stats: Dict[str, int],
+    ) -> None:
+        """
+        Fetch completed TaskResult objects from the worker pool and apply
+        their db_update and stats_update payloads. Only existing DataFrame rows
+        (matched by df_idx) are upserted via job_db.persist(); results targeting
+        unknown df_idx indices are logged as errors but not persisted.
+
+        :param worker_pool: thread pool managing asynchronous task execution
+        :param job_db: interface used to upsert into the job database
+        :param stats: dictionary accumulating statistics counters
+        """
+        # Retrieve completed task results immediately
+        results, _ = worker_pool.process_futures(timeout=0)
+
+        # Collect update dicts
+        updates: List[Dict[str, Any]] = []
+        for res in results:
+            # Process database updates
+            if res.db_update:
+                try:
+                    updates.append(
+                        {
+                            "id": res.job_id,
+                            "df_idx": res.df_idx,
+                            **res.db_update,
+                        }
+                    )
+                except Exception as e:
+                    _log.error(f"Skipping invalid db_update {res.db_update!r} for job {res.job_id!r}: {e}")
+
+            # Process stats updates
+            if res.stats_update:
+                try:
+                    for key, val in res.stats_update.items():
+                        count = int(val)
+                        stats[key] = stats.get(key, 0) + count
+                except Exception as e:
+                    _log.error(f"Skipping invalid stats_update {res.stats_update!r} for job {res.job_id!r}: {e}")
+
+        # No valid updates: nothing to persist
+        if not updates:
+            return
+
+        # Build update DataFrame and persist
+        df_updates = job_db.get_by_indices(indices=set(u["df_idx"] for u in updates))
+        df_updates.update(pd.DataFrame(updates).set_index("df_idx", drop=True), overwrite=True)
+        job_db.persist(df_updates)
+        stats["job_db persist"] = stats.get("job_db persist", 0) + 1
+
+
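
# Self-contained illustration (made-up values) of the upsert step above: updates are
# aligned on `df_idx` and only overwrite rows/columns that already exist in the job db.
import pandas as pd

df_existing = pd.DataFrame(
    {"id": ["job-1", "job-2"], "status": ["queued_for_start", "running"]},
    index=[0, 1],
)
updates = [{"id": "job-1", "df_idx": 0, "status": "queued"}]

df_updates = df_existing.loc[[u["df_idx"] for u in updates]].copy()
df_updates.update(pd.DataFrame(updates).set_index("df_idx", drop=True), overwrite=True)
print(df_updates.loc[0, "status"])  # -> "queued"
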
     def on_job_done(self, job: BatchJob, row):
         """
         Handles jobs that have finished. Can be overridden to provide custom behaviour.
@@ -674,20 +782,19 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
         try:
             # Ensure running start time is valid
             job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True)
-
+
             # Parse the current time into a datetime object with timezone info
             current_time = rfc3339.parse_datetime(rfc3339.now_utc(), with_timezone=True)
 
             # Calculate the elapsed time between job start and now
             elapsed = current_time - job_running_start_time
 
             if elapsed > self._cancel_running_job_after:
-
                 _log.info(
                     f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})"
                 )
                 job.stop()
-
+
         except Exception as e:
             _log.error(f"Unexpected error while handling job {job.job_id}: {e}")
 
@@ -716,7 +823,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
         """
         stats = stats if stats is not None else collections.defaultdict(int)
 
-        active = job_db.get_by_status(statuses=["created", "queued", "running"]).copy()
+        active = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]).copy()
 
         jobs_done = []
         jobs_error = []
@@ -738,7 +845,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
738845 f"Status of job { job_id !r} (on backend { backend_name } ) is { new_status !r} (previously { previous_status !r} )"
739846 )
740847
741- if new_status == "finished" :
848+ if previous_status != "finished" and new_status == "finished" :
742849 stats ["job finished" ] += 1
743850 jobs_done .append ((the_job , active .loc [i ]))
744851
@@ -750,7 +857,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
                     stats["job canceled"] += 1
                     jobs_cancel.append((the_job, active.loc[i]))
 
-                if previous_status in {"created", "queued"} and new_status == "running":
+                if previous_status in {"created", "queued", "queued_for_start"} and new_status == "running":
                     stats["job started running"] += 1
                     active.loc[i, "running_start_time"] = rfc3339.now_utc()
 
@@ -874,10 +981,21 @@ def get_by_status(self, statuses, max=None) -> pd.DataFrame:
 
     def _merge_into_df(self, df: pd.DataFrame):
         if self._df is not None:
+            unknown_indices = set(df.index).difference(self._df.index)
+            if unknown_indices:
+                _log.warning(f"Merging DataFrame with {unknown_indices=} which will be lost.")
             self._df.update(df, overwrite=True)
         else:
             self._df = df
 
+    def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
+        indices = set(indices)
+        known = indices.intersection(self.df.index)
+        unknown = indices.difference(self.df.index)
+        if unknown:
+            _log.warning(f"Ignoring unknown DataFrame indices {unknown}")
+        return self._df.loc[list(known)]
+
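
# Hedged usage sketch of `get_by_indices` on a concrete FullDataFrameJobDatabase
# subclass. The file name and dataframe contents are placeholders; the CSV file is
# assumed not to exist yet.
import pandas as pd

from openeo.extra.job_management import CsvJobDatabase

job_db = CsvJobDatabase("jobs.csv").initialize_from_df(pd.DataFrame({"some_param": [1, 2, 3]}))
subset = job_db.get_by_indices([0, 2, 999])  # 999 is unknown: logged as a warning and ignored
print(sorted(subset.index.tolist()))  # -> [0, 2] (a set is used internally, so order may vary)
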
 
 class CsvJobDatabase(FullDataFrameJobDatabase):
     """