Skip to content

Commit 4171004

Browse files
committed
Merge branch 'hv_issue719-job-manager-threaded-job-start-rebase'
2 parents 8cafa08 + 2753365 commit 4171004

File tree

8 files changed

+866
-39
lines changed

8 files changed

+866
-39
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Changed
1515

16+
- `MultiBackendJobManager`: starting jobs (which can take a long time in some situations) is now done in side threads to avoid blocking the main job management thread, improving its responsiveness and allowing better overall throughput. To make this possible, a new method `get_by_indices()` was added to the `JobDatabaseInterface` API. Make sure to implement this method if you have a custom `JobDatabaseInterface` implementation that does not provide it yet. ([#719](https://github.com/Open-EO/openeo-python-client/issues/719))
17+
1618
### Removed
1719

1820
### Fixed

openeo/extra/job_management/__init__.py

Lines changed: 140 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,17 @@
3333
from urllib3.util import Retry
3434

3535
from openeo import BatchJob, Connection
36+
from openeo.extra.job_management._thread_worker import (
37+
_JobManagerWorkerThreadPool,
38+
_JobStartTask,
39+
)
3640
from openeo.internal.processes.parse import (
3741
Parameter,
3842
Process,
3943
parse_remote_process_definition,
4044
)
4145
from openeo.rest import OpenEoApiError
46+
from openeo.rest.auth.auth import BearerAuth
4247
from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339
4348

4449
_log = logging.getLogger(__name__)
@@ -76,8 +81,10 @@ def exists(self) -> bool:
7681
@abc.abstractmethod
7782
def persist(self, df: pd.DataFrame):
7883
"""
79-
Store job data to the database.
80-
The provided dataframe may contain partial information, which is merged into the larger database.
84+
Store (new or updated) job data to the database.
85+
86+
The provided dataframe may only cover a subset of all the jobs ("rows") of the whole database,
87+
so it should be merged with the existing data (if any) instead of overwriting it completely.
8188
8289
:param df: job data to store.
8390
"""
@@ -106,6 +113,18 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame:
106113
"""
107114
...
108115

116+
@abc.abstractmethod
def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
    """
    Look up jobs by their (dataframe) index.

    :param indices: index labels of the jobs to select.

    :return: DataFrame containing the jobs with the given indices.
    """
    ...
126+
127+
109128
def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs):
    """
    Fallback ``start_job`` callable that always fails.

    All arguments are accepted and ignored.

    :raises NotImplementedError: always, signalling that no ``start_job`` callable was provided.
    """
    message = "No 'start_job' callable provided"
    raise NotImplementedError(message)
111130

@@ -187,6 +206,7 @@ def start_job(
187206

188207
# Expected columns in the job DB dataframes.
189208
# TODO: make this part of public API when settled?
209+
# TODO: move non-official statuses to a separate column (not_started, queued_for_start)
190210
_COLUMN_REQUIREMENTS: Mapping[str, _ColumnProperties] = {
191211
"id": _ColumnProperties(dtype="str"),
192212
"backend_name": _ColumnProperties(dtype="str"),
@@ -223,6 +243,7 @@ def __init__(
223243
datetime.timedelta(seconds=cancel_running_job_after) if cancel_running_job_after is not None else None
224244
)
225245
self._thread = None
246+
self._worker_pool = None
226247

227248
def add_backend(
228249
self,
@@ -359,21 +380,27 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas
359380
_log.info(f"Resuming `run_jobs` from existing {job_db}")
360381

361382
self._stop_thread = False
383+
self._worker_pool = _JobManagerWorkerThreadPool()
362384

363385
def run_loop():
364386

365387
# TODO: support user-provided `stats`
366388
stats = collections.defaultdict(int)
367389

368390
while (
369-
sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0
391+
sum(
392+
job_db.count_by_status(
393+
statuses=["not_started", "created", "queued", "queued_for_start", "running"]
394+
).values()
395+
)
396+
> 0
370397
and not self._stop_thread
371398
):
372-
self._job_update_loop(job_db=job_db, start_job=start_job)
399+
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
373400
stats["run_jobs loop"] += 1
374401

402+
# Show current stats and sleep
375403
_log.info(f"Job status histogram: {job_db.count_by_status()}. Run stats: {dict(stats)}")
376-
# Do sequence of micro-sleeps to allow for quick thread exit
377404
for _ in range(int(max(1, self.poll_sleep))):
378405
time.sleep(1)
379406
if self._stop_thread:
@@ -392,6 +419,8 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):
392419
393420
.. versionadded:: 0.32.0
394421
"""
422+
self._worker_pool.shutdown()
423+
395424
if self._thread is not None:
396425
self._stop_thread = True
397426
if timeout_seconds is _UNSET:
@@ -494,7 +523,16 @@ def run_jobs(
494523
# TODO: support user-provided `stats`
495524
stats = collections.defaultdict(int)
496525

497-
while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0:
526+
self._worker_pool = _JobManagerWorkerThreadPool()
527+
528+
while (
529+
sum(
530+
job_db.count_by_status(
531+
statuses=["not_started", "created", "queued_for_start", "queued", "running"]
532+
).values()
533+
)
534+
> 0
535+
):
498536
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
499537
stats["run_jobs loop"] += 1
500538

@@ -503,6 +541,9 @@ def run_jobs(
503541
time.sleep(self.poll_sleep)
504542
stats["sleep"] += 1
505543

544+
# TODO: run post-processing once more after shutdown to ensure completion?
545+
self._worker_pool.shutdown()
546+
506547
return stats
507548

508549
def _job_update_loop(
@@ -525,7 +566,8 @@ def _job_update_loop(
525566
not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy()
526567
if len(not_started) > 0:
527568
# Check number of jobs running at each backend
528-
running = job_db.get_by_status(statuses=["created", "queued", "running"])
569+
# TODO: should "created" be included in here? Calling this "running" is quite misleading then.
570+
running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"])
529571
stats["job_db get_by_status"] += 1
530572
per_backend = running.groupby("backend_name").size().to_dict()
531573
_log.info(f"Running per backend: {per_backend}")
@@ -542,7 +584,9 @@ def _job_update_loop(
542584
stats["job_db persist"] += 1
543585
total_added += 1
544586

545-
# Act on jobs
587+
self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
588+
589+
# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
546590
for job, row in jobs_done:
547591
self.on_job_done(job, row)
548592

@@ -552,7 +596,6 @@ def _job_update_loop(
552596
for job, row in jobs_cancel:
553597
self.on_job_cancel(job, row)
554598

555-
556599
def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
557600
"""Helper method for launching jobs
558601
@@ -599,26 +642,91 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
599642
df.loc[i, "start_time"] = rfc3339.now_utc()
600643
if job:
601644
df.loc[i, "id"] = job.job_id
645+
_log.info(f"Job created: {job.job_id}")
602646
with ignore_connection_errors(context="get status"):
603647
status = job.status()
604648
stats["job get status"] += 1
605649
df.loc[i, "status"] = status
606650
if status == "created":
607651
# start job if not yet done by callback
608652
try:
609-
job.start()
610-
stats["job start"] += 1
611-
df.loc[i, "status"] = job.status()
612-
stats["job get status"] += 1
653+
job_con = job.connection
654+
task = _JobStartTask(
655+
root_url=job_con.root_url,
656+
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
657+
job_id=job.job_id,
658+
df_idx=i,
659+
)
660+
_log.info(f"Submitting task {task} to thread pool")
661+
self._worker_pool.submit_task(task)
662+
663+
stats["job_queued_for_start"] += 1
664+
df.loc[i, "status"] = "queued_for_start"
613665
except OpenEoApiError as e:
614-
_log.error(e)
615-
df.loc[i, "status"] = "start_failed"
616-
stats["job start error"] += 1
666+
_log.info(f"Failed submitting task {task} to thread pool with error: {e}")
667+
df.loc[i, "status"] = "queued_for_start_failed"
668+
stats["job queued for start failed"] += 1
617669
else:
618670
# TODO: what is this "skipping" about actually?
619671
df.loc[i, "status"] = "skipped"
620672
stats["start_job skipped"] += 1
621673

674+
def _process_threadworker_updates(
    self,
    worker_pool: _JobManagerWorkerThreadPool,
    *,
    job_db: JobDatabaseInterface,
    stats: Dict[str, int],
) -> None:
    """
    Collect the results of finished worker-thread tasks and apply them.

    For each result returned by ``worker_pool.process_futures(timeout=0)``:

    - a non-empty ``db_update`` is queued as a row update, keyed by the result's ``df_idx``
      (the result's ``job_id`` is used as ``"id"`` value unless ``db_update`` overrides it);
    - a non-empty ``stats_update`` is added, counter by counter, to ``stats``.
      An invalid ``stats_update`` is skipped as a whole (nothing of it is applied).

    The queued row updates are then applied on top of the rows obtained from
    ``job_db.get_by_indices()`` and stored again with ``job_db.persist()``.
    Updates for indices that ``get_by_indices()`` does not return are dropped;
    if it returns no rows at all, nothing is persisted.

    :param worker_pool: thread pool to fetch finished task results from.
    :param job_db: job database to read the current rows from and to persist the updated rows to.
    :param stats: dictionary of statistic counters, updated in place.
    """
    # Retrieve completed task results immediately
    results, _ = worker_pool.process_futures(timeout=0)

    # Collect row updates (one dict per result that carries a db_update)
    updates: List[Dict[str, Any]] = []
    for res in results:
        # Process database updates
        if res.db_update:
            try:
                # "id" comes first, so an explicit "id" in `db_update` takes precedence.
                updates.append(
                    {
                        "id": res.job_id,
                        "df_idx": res.df_idx,
                        **res.db_update,
                    }
                )
            except Exception as e:
                _log.error(f"Skipping invalid db_update {res.db_update!r} for job {res.job_id!r}: {e}")

        # Process stats updates
        if res.stats_update:
            try:
                # Validate all values first, so that an invalid stats_update
                # is skipped completely instead of being applied half-way.
                increments = {key: int(val) for key, val in res.stats_update.items()}
            except Exception as e:
                _log.error(f"Skipping invalid stats_update {res.stats_update!r} for job {res.job_id!r}: {e}")
            else:
                for key, count in increments.items():
                    stats[key] = stats.get(key, 0) + count

    # No valid updates: nothing to persist
    if not updates:
        return

    # Fetch the current rows of the jobs to update
    df_updates = job_db.get_by_indices(indices={u["df_idx"] for u in updates})
    if len(df_updates) == 0:
        # None of the targeted rows exist (anymore): avoid a pointless persist round trip.
        _log.warning(f"Dropping {len(updates)} job update(s): no matching rows in job database")
        return

    # Overlay the updates (aligned on dataframe index) and persist
    df_updates.update(pd.DataFrame(updates).set_index("df_idx", drop=True), overwrite=True)
    job_db.persist(df_updates)
    stats["job_db persist"] = stats.get("job_db persist", 0) + 1
728+
729+
622730
def on_job_done(self, job: BatchJob, row):
623731
"""
624732
Handles jobs that have finished. Can be overridden to provide custom behaviour.
@@ -674,20 +782,19 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
674782
try:
675783
# Ensure running start time is valid
676784
job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True)
677-
785+
678786
# Parse the current time into a datetime object with timezone info
679787
current_time = rfc3339.parse_datetime(rfc3339.now_utc(), with_timezone=True)
680788

681789
# Calculate the elapsed time between job start and now
682790
elapsed = current_time - job_running_start_time
683791

684792
if elapsed > self._cancel_running_job_after:
685-
686793
_log.info(
687794
f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})"
688795
)
689796
job.stop()
690-
797+
691798
except Exception as e:
692799
_log.error(f"Unexpected error while handling job {job.job_id}: {e}")
693800

@@ -716,7 +823,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
716823
"""
717824
stats = stats if stats is not None else collections.defaultdict(int)
718825

719-
active = job_db.get_by_status(statuses=["created", "queued", "running"]).copy()
826+
active = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]).copy()
720827

721828
jobs_done = []
722829
jobs_error = []
@@ -738,7 +845,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
738845
f"Status of job {job_id!r} (on backend {backend_name}) is {new_status!r} (previously {previous_status!r})"
739846
)
740847

741-
if new_status == "finished":
848+
if previous_status != "finished" and new_status == "finished":
742849
stats["job finished"] += 1
743850
jobs_done.append((the_job, active.loc[i]))
744851

@@ -750,7 +857,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] =
750857
stats["job canceled"] += 1
751858
jobs_cancel.append((the_job, active.loc[i]))
752859

753-
if previous_status in {"created", "queued"} and new_status == "running":
860+
if previous_status in {"created", "queued", "queued_for_start"} and new_status == "running":
754861
stats["job started running"] += 1
755862
active.loc[i, "running_start_time"] = rfc3339.now_utc()
756863

@@ -874,10 +981,21 @@ def get_by_status(self, statuses, max=None) -> pd.DataFrame:
874981

875982
def _merge_into_df(self, df: pd.DataFrame):
876983
if self._df is not None:
984+
unknown_indices = set(df.index).difference(df.index)
985+
if unknown_indices:
986+
_log.warning(f"Merging DataFrame with {unknown_indices=} which will be lost.")
877987
self._df.update(df, overwrite=True)
878988
else:
879989
self._df = df
880990

991+
def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
    """
    Select the rows of the job dataframe with the given index labels.

    Labels that do not occur in the dataframe index are skipped (with a warning).
    Duplicate labels are collapsed, and rows are returned in the order
    in which their label is first encountered in ``indices``.

    :param indices: index labels of the jobs to select.

    :return: DataFrame with the rows for the known labels.
    """
    # Use a single dataframe handle for both the lookup and the selection.
    # NOTE(review): assumes `self.df` exposes the same data as `self._df` - confirm.
    df = self.df

    # `dict.fromkeys` de-duplicates while keeping first-seen order,
    # which makes the row order of the result deterministic.
    requested = list(dict.fromkeys(indices))
    known = [label for label in requested if label in df.index]
    unknown = set(requested).difference(known)
    if unknown:
        _log.warning(f"Ignoring unknown DataFrame indices {unknown}")
    return df.loc[known]
998+
881999

8821000
class CsvJobDatabase(FullDataFrameJobDatabase):
8831001
"""

0 commit comments

Comments
 (0)