Skip to content

Commit 434854f

Browse files
committed
Introduce JobDatabaseInterface.get_by_indices
- make sure `persist` gets complete rows - implement in both `FullDataFrameJobDatabase` and `STACAPIJobDatabase` refs: #719, #736, #793
1 parent ed28aaa commit 434854f

File tree

3 files changed

+99
-24
lines changed

3 files changed

+99
-24
lines changed

openeo/extra/job_management/__init__.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ def exists(self) -> bool:
8181
@abc.abstractmethod
8282
def persist(self, df: pd.DataFrame):
8383
"""
84-
Store job data to the database.
85-
The provided dataframe may contain partial information, which is merged into the larger database.
84+
Store (new or updated) job data to the database.
85+
86+
The provided dataframe may only cover a subset of all the jobs ("rows") of the whole database,
87+
so it should be merged with the existing data (if any) instead of overwriting it completely.
8688
8789
:param df: job data to store.
8890
"""
@@ -111,6 +113,17 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame:
111113
"""
112114
...
113115

116+
@abc.abstractmethod
117+
def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
118+
"""
119+
Returns a dataframe with jobs based on their (dataframe) index
120+
121+
:param indices: List of indices to include.
122+
123+
:return: DataFrame with jobs filtered by indices.
124+
"""
125+
...
126+
114127

115128
def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs):
116129
raise NotImplementedError("No 'start_job' callable provided")
@@ -707,9 +720,9 @@ def _process_threadworker_updates(
707720
if not updates:
708721
return
709722

710-
# Build DataFrame of updates indexed by df_idx
711-
df_updates = pd.DataFrame(updates).set_index("df_idx", drop=True)
712-
723+
# Build update DataFrame and persist
724+
df_updates = job_db.get_by_indices(indices=set(u["df_idx"] for u in updates))
725+
df_updates.update(pd.DataFrame(updates).set_index("df_idx", drop=True), overwrite=True)
713726
job_db.persist(df_updates)
714727
stats["job_db persist"] = stats.get("job_db persist", 0) + 1
715728

@@ -968,10 +981,21 @@ def get_by_status(self, statuses, max=None) -> pd.DataFrame:
968981

969982
def _merge_into_df(self, df: pd.DataFrame):
970983
if self._df is not None:
984+
unknown_indices = set(df.index).difference(self._df.index)
985+
if unknown_indices:
986+
_log.warning(f"Merging DataFrame with {unknown_indices=} which will be lost.")
971987
self._df.update(df, overwrite=True)
972988
else:
973989
self._df = df
974990

991+
def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
992+
indices = set(indices)
993+
known = indices.intersection(self.df.index)
994+
unknown = indices.difference(self.df.index)
995+
if unknown:
996+
_log.warning(f"Ignoring unknown DataFrame indices {unknown}")
997+
return self._df.loc[list(known)]
998+
975999

9761000
class CsvJobDatabase(FullDataFrameJobDatabase):
9771001
"""

openeo/extra/job_management/stac_job_db.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import concurrent.futures
22
import datetime
33
import logging
4-
from typing import Iterable, List, Optional
4+
from typing import Iterable, List, Optional, Union
55

66
import geopandas as gpd
77
import numpy as np
@@ -165,6 +165,12 @@ def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
165165
else:
166166
return items["status"].value_counts().to_dict()
167167

168+
def _search_result_to_df(self, search_result: pystac_client.ItemSearch) -> pd.DataFrame:
169+
"""Build a DataFrame from a STAC ItemSearch result."""
170+
series = [self.series_from(item) for item in search_result.items()]
171+
df = pd.DataFrame(series).reset_index(names=["item_id"])
172+
return df
173+
168174
def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> pd.DataFrame:
169175
if isinstance(statuses, str):
170176
statuses = {statuses}
@@ -178,16 +184,24 @@ def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> p
178184
max_items=max,
179185
)
180186

181-
series = [self.series_from(item) for item in search_results.items()]
187+
df = self._search_result_to_df(search_results)
182188

183-
df = pd.DataFrame(series).reset_index(names=["item_id"])
184-
if len(series) == 0:
189+
if df.shape[0] == 0:
185190
# TODO: What if default columns are overwritten by the user?
186191
df = self._normalize_df(
187192
df
188193
) # Even for an empty dataframe the default columns are required
189194
return df
190195

196+
def get_by_indices(self, indices: Iterable[Union[int, str]]) -> pd.DataFrame:
197+
search_results = self.client.search(
198+
method="GET",
199+
collections=[self.collection_id],
200+
ids=[str(i) for i in indices],
201+
)
202+
df = self._search_result_to_df(search_results)
203+
return df
204+
191205
def persist(self, df: pd.DataFrame):
192206
if not self.exists():
193207
spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])

tests/extra/job_management/test_job_management.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525
import pandas
2626
import pandas as pd
27+
import pandas.testing
2728
import pytest
2829
import requests
2930
import shapely.geometry
@@ -745,7 +746,6 @@ def get_status(job_id, current_status):
745746
filled_running_start_time = final_df.iloc[0]["running_start_time"]
746747
assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime)
747748

748-
749749
def test_process_threadworker_updates(self, tmp_path, caplog):
750750
pool = _JobManagerWorkerThreadPool(max_workers=2)
751751
stats = collections.defaultdict(int)
@@ -755,8 +755,6 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
755755
pool.submit_task(DummyTask("j-1", df_idx=1, db_update={"status": "queued"}, stats_update=None))
756756
pool.submit_task(DummyTask("j-2", df_idx=2, db_update=None, stats_update={"queued": 1}))
757757
pool.submit_task(DummyTask("j-3", df_idx=3, db_update=None, stats_update=None))
758-
# Invalid index (not in DB)
759-
pool.submit_task(DummyTask("j-missing", df_idx=4, db_update={"status": "created"}, stats_update=None))
760758

761759
df_initial = pd.DataFrame(
762760
{
@@ -768,23 +766,62 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
768766

769767
mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs")
770768

771-
with caplog.at_level(logging.ERROR):
772-
mgr._process_threadworker_updates(worker_pool=pool, job_db=job_db, stats=stats)
769+
mgr._process_threadworker_updates(worker_pool=pool, job_db=job_db, stats=stats)
773770

774771
df_final = job_db.read()
772+
pandas.testing.assert_frame_equal(
773+
df_final[["id", "status"]],
774+
pandas.DataFrame(
775+
{
776+
"id": ["j-0", "j-1", "j-2", "j-3"],
777+
"status": ["queued", "queued", "created", "created"],
778+
}
779+
),
780+
)
781+
assert stats == dirty_equals.IsPartialDict(
782+
{
783+
"queued": 2,
784+
"job_db persist": 1,
785+
}
786+
)
787+
assert caplog.messages == []
788+
789+
def test_process_threadworker_updates_unknown(self, tmp_path, caplog):
790+
pool = _JobManagerWorkerThreadPool(max_workers=2)
791+
stats = collections.defaultdict(int)
792+
793+
pool.submit_task(DummyTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1}))
794+
pool.submit_task(DummyTask("j-unknown", df_idx=4, db_update={"status": "created"}, stats_update=None))
795+
796+
df_initial = pd.DataFrame(
797+
{
798+
"id": ["j-123", "j-456"],
799+
"status": ["created", "created"],
800+
}
801+
)
802+
job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial)
775803

776-
# Assert no rows were appended
777-
assert len(df_final) == 4
804+
mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs")
778805

779-
# Assert updates
780-
assert df_final.loc[0, "status"] == "queued"
781-
assert df_final.loc[1, "status"] == "queued"
782-
assert df_final.loc[2, "status"] == "created"
783-
assert df_final.loc[3, "status"] == "created"
806+
mgr._process_threadworker_updates(worker_pool=pool, job_db=job_db, stats=stats)
784807

785-
# Assert stats
786-
assert stats.get("queued", 0) == 2
787-
assert stats["job_db persist"] == 1
808+
df_final = job_db.read()
809+
pandas.testing.assert_frame_equal(
810+
df_final[["id", "status"]],
811+
pandas.DataFrame(
812+
{
813+
"id": ["j-123", "j-456"],
814+
"status": ["queued", "created"],
815+
}
816+
),
817+
)
818+
assert stats == dirty_equals.IsPartialDict(
819+
{
820+
"queued": 1,
821+
"job_db persist": 1,
822+
}
823+
)
824+
assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")]
788825

789826
def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
790827
pool = _JobManagerWorkerThreadPool(max_workers=2)

0 commit comments

Comments
 (0)