
Commit 7064cf3

Merge pull request #619 from Open-EO/stac_jobdb
add job database implementation that uses stac
2 parents e106045 + 8bb7f37 commit 7064cf3

File tree

5 files changed: +693 -1 lines changed

docs/installation.rst

Lines changed: 1 addition & 0 deletions

@@ -92,6 +92,7 @@ For example:
 - ``rioxarray`` for GeoTIFF support in the assert helpers from ``openeo.testing.results``
 - ``geopandas`` for working with dataframes with geospatial support,
   (e.g. with :py:class:`~openeo.extra.job_management.MultiBackendJobManager`)
+- ``pystac_client`` for creating a STAC API Job Database (e.g. with :py:class:`~openeo.extra.job_management.stac_job_db.STACAPIJobDatabase`)


 Enabling additional features
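
To illustrate the new entry above: a minimal, hypothetical sketch of constructing a STAC API job database once ``pystac_client`` is installed (the collection ID, STAC root URL, and credentials are placeholders, not values from this commit):

    import requests

    from openeo.extra.job_management.stac_job_db import STACAPIJobDatabase

    # Hypothetical STAC API endpoint and basic-auth credentials.
    job_db = STACAPIJobDatabase(
        collection_id="my-openeo-jobs",
        stac_root_url="https://stac.example.org",
        auth=requests.auth.HTTPBasicAuth("user", "password"),
    )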

openeo/extra/job_management/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -104,7 +104,6 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame:
         """
         ...
 
-
 def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs):
     raise NotImplementedError("No 'start_job' callable provided")

openeo/extra/job_management/stac_job_db.py

Lines changed: 300 additions & 0 deletions

@@ -0,0 +1,300 @@
import concurrent.futures
import datetime
import logging
from typing import Iterable, List

import geopandas as gpd
import numpy as np
import pandas as pd
import pystac
import pystac_client
import requests
from shapely.geometry import mapping, shape

from openeo.extra.job_management import JobDatabaseInterface, MultiBackendJobManager

_log = logging.getLogger(__name__)


class STACAPIJobDatabase(JobDatabaseInterface):
    """
    Persist/load job metadata from a STAC API.

    Unstable API, subject to change.

    :implements: :py:class:`JobDatabaseInterface`
    """

    def __init__(
        self,
        collection_id: str,
        stac_root_url: str,
        auth: requests.auth.AuthBase,
        has_geometry: bool = False,
        geometry_column: str = "geometry",
    ):
        """
        Initialize the STACAPIJobDatabase.

        :param collection_id: The ID of the STAC collection.
        :param stac_root_url: The root URL of the STAC API.
        :param auth: requests ``AuthBase`` that will be used to authenticate, e.g. OAuth2ResourceOwnerPasswordCredentials.
        :param has_geometry: Whether the job metadata contains geometry that implements ``__geo_interface__``.
        :param geometry_column: The name of the geometry column in the job metadata that implements ``__geo_interface__``.
        """
        self.collection_id = collection_id
        self.client = pystac_client.Client.open(stac_root_url)

        self._auth = auth
        self.has_geometry = has_geometry
        self.geometry_column = geometry_column
        self.base_url = stac_root_url
        self.bulk_size = 500

    def exists(self) -> bool:
        return any(c.id == self.collection_id for c in self.client.get_collections())

    def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"):
        """
        Initialize the job database from a given dataframe,
        which will first be normalized to be compatible
        with :py:class:`MultiBackendJobManager` usage.

        :param df: dataframe with some columns your ``start_job`` callable expects
        :param on_exists: what to do when the job database already exists:
            - "error": (default) raise an exception
            - "skip": work with existing database, ignore given dataframe and skip any initialization
            - "append": add given dataframe to existing database

        :return: initialized job database.
        """
        if isinstance(df, gpd.GeoDataFrame):
            df = df.copy()
            _log.warning("Job database is initialized from GeoDataFrame. Converting geometries to GeoJSON.")
            self.geometry_column = df.geometry.name
            df[self.geometry_column] = df[self.geometry_column].apply(lambda x: mapping(x))
            df = pd.DataFrame(df)
            self.has_geometry = True

        if self.exists():
            if on_exists == "skip":
                return self
            elif on_exists == "error":
                raise FileExistsError(f"Job database {self!r} already exists.")
            elif on_exists == "append":
                existing_df = self.get_by_status([])
                df = MultiBackendJobManager._normalize_df(df)
                df = pd.concat([existing_df, df], ignore_index=True).replace({np.nan: None})
                self.persist(df)
                return self
            else:
                raise ValueError(f"Invalid on_exists={on_exists!r}")

        df = MultiBackendJobManager._normalize_df(df)
        self.persist(df)
        # Return self to allow chaining with constructor.
        return self

    def series_from(self, item: pystac.Item) -> pd.Series:
        """
        Convert a STAC Item to a pandas.Series.

        :param item: STAC Item to be converted.
        :return: pandas.Series
        """
        item_dict = item.to_dict()
        item_id = item_dict["id"]

        return pd.Series(item_dict["properties"], name=item_id)

    def item_from(self, series: pd.Series) -> pystac.Item:
        """
        Convert a pandas.Series to a STAC Item.

        :param series: pandas.Series to be converted.
        :return: pystac.Item
        """
        series_dict = series.to_dict()
        item_dict = {}
        item_dict.setdefault("stac_version", pystac.get_stac_version())
        item_dict.setdefault("type", "Feature")
        item_dict.setdefault("assets", {})
        item_dict.setdefault("links", [])
        item_dict.setdefault("properties", series_dict)

        dt = series_dict.get("datetime", None)
        if dt:
            # Keep the datetime from the series, serializing datetime objects to ISO strings.
            item_dict["properties"]["datetime"] = (
                pystac.utils.datetime_to_str(dt) if isinstance(dt, datetime.datetime) else dt
            )
        else:
            # STAC Items require a datetime: fall back to the current time.
            item_dict["properties"]["datetime"] = pystac.utils.datetime_to_str(datetime.datetime.now())

        if self.has_geometry:
            item_dict["geometry"] = series[self.geometry_column]
        else:
            item_dict["geometry"] = None

        # from_dict handles associating any Links and Assets with the Item
        item_dict["id"] = series.name
        item = pystac.Item.from_dict(item_dict)
        if self.has_geometry:
            item.bbox = shape(series[self.geometry_column]).bounds
        else:
            item.bbox = None
        return item

    def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
        if isinstance(statuses, str):
            statuses = {statuses}
        statuses = set(statuses)
        # get_by_status always returns a (possibly empty) dataframe;
        # note that counts are based on at most 200 matching items.
        items = self.get_by_status(statuses, max=200)
        counts = items["status"].value_counts().to_dict()
        # Report an explicit zero count for requested statuses without matching items.
        for s in statuses:
            counts.setdefault(s, 0)
        return counts

    def get_by_status(self, statuses: Iterable[str], max=None) -> pd.DataFrame:
        if isinstance(statuses, str):
            statuses = {statuses}
        statuses = set(statuses)

        status_filter = " OR ".join([f"\"properties.status\"='{s}'" for s in statuses]) if statuses else None
        search_results = self.client.search(
            method="GET",
            collections=[self.collection_id],
            filter=status_filter,
            max_items=max,
        )

        series = [self.series_from(item) for item in search_results.items()]

        df = pd.DataFrame(series)
        if len(series) == 0:
            # TODO: What if default columns are overwritten by the user?
            # Even for an empty dataframe the default columns are required.
            df = MultiBackendJobManager._normalize_df(df)
        return df

    def persist(self, df: pd.DataFrame):
        if not self.exists():
            spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
            temporal_extent = pystac.TemporalExtent([[None, None]])
            extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
            c = pystac.Collection(id=self.collection_id, description="STAC API job database collection.", extent=extent)
            self._create_collection(c)

        all_items = []
        if not df.empty:

            def handle_row(series):
                item = self.item_from(series)
                all_items.append(item)

            df.apply(handle_row, axis=1)

        self._upload_items_bulk(self.collection_id, all_items)

    def _prepare_item(self, item: pystac.Item, collection_id: str):
        item.collection_id = collection_id

        if not item.get_links(pystac.RelType.COLLECTION):
            item.add_link(pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id))

    def _ingest_bulk(self, items: List[pystac.Item]) -> dict:
        collection_id = items[0].collection_id
        if not all(i.collection_id == collection_id for i in items):
            raise Exception("All collection IDs should be identical for bulk ingests")

        url_path = f"collections/{collection_id}/bulk_items"
        data = {"method": "upsert", "items": {item.id: item.to_dict() for item in items}}
        response = requests.post(url=self.join_url(url_path), auth=self._auth, json=data)

        _log.info(f"HTTP response: {response.status_code} - {response.reason}: body: {response.json()}")

        _check_response_status(response, _EXPECTED_STATUS_POST)
        return response.json()

    def _upload_items_bulk(self, collection_id: str, items: List[pystac.Item]) -> None:
        chunk = []
        futures = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            for item in items:
                self._prepare_item(item, collection_id)
                chunk.append(item)

                if len(chunk) == self.bulk_size:
                    futures.append(executor.submit(self._ingest_bulk, chunk.copy()))
                    chunk = []

            # Ingest the final, partially filled chunk synchronously.
            if chunk:
                self._ingest_bulk(chunk)

            # Wait for all submitted bulk ingests to finish.
            for _ in concurrent.futures.as_completed(futures):
                continue

    def join_url(self, url_path: str) -> str:
        """Create a URL from the base_url and the url_path.

        :param url_path: URL path relative to the base URL.
        :return: the full URL as a string.
        """
        return str(self.base_url + "/" + url_path)

    def _create_collection(self, collection: pystac.Collection) -> dict:
        """Create a new collection.

        :param collection: pystac.Collection object to create in the STAC API backend (or upload if you will)
        :raises TypeError: if collection is not a pystac.Collection.
        :return: dict that contains the JSON body of the HTTP response.
        """

        if not isinstance(collection, pystac.Collection):
            raise TypeError(f'Argument "collection" must be of type pystac.Collection, but is {type(collection)}')

        collection.validate()
        coll_dict = collection.to_dict()

        # Backend-specific default access control settings for the new collection.
        default_auth = {
            "_auth": {
                "read": ["anonymous"],
                "write": ["stac-openeo-admin", "stac-openeo-editor"],
            }
        }

        coll_dict.update(default_auth)

        response = requests.post(self.join_url("collections"), auth=self._auth, json=coll_dict)
        _check_response_status(response, _EXPECTED_STATUS_POST)

        return response.json()


_EXPECTED_STATUS_POST = [
    requests.status_codes.codes.ok,
    requests.status_codes.codes.created,
    requests.status_codes.codes.accepted,
]


def _check_response_status(response: requests.Response, expected_status_codes: List[int], raise_exc: bool = False):
    if response.status_code not in expected_status_codes:
        message = (
            f"Expecting HTTP status to be any of {expected_status_codes} "
            + f"but received {response.status_code} - {response.reason}, request method={response.request.method}\n"
            + f"response body:\n{response.text}"
        )
        if raise_exc:
            raise Exception(message)
        else:
            _log.warning(message)

    # Always raise errors on 4xx and 5xx status codes.
    response.raise_for_status()
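
For context, a hedged end-to-end sketch of how this job database plugs into :py:class:`MultiBackendJobManager` (the backend URL, collection IDs, dataframe columns, and the ``start_job`` callable below are illustrative assumptions, not part of this commit):

    import openeo
    import pandas as pd
    import requests

    from openeo.extra.job_management import MultiBackendJobManager
    from openeo.extra.job_management.stac_job_db import STACAPIJobDatabase


    def start_job(row: pd.Series, connection: openeo.Connection, **kwargs):
        # Illustrative: build a batch job from the metadata in one dataframe row.
        cube = connection.load_collection(
            "SENTINEL2_L2A",
            temporal_extent=[f"{row['year']}-01-01", f"{row['year']}-12-31"],
        )
        return cube.create_job(title=f"job-{row['year']}")


    # Hypothetical job list; initialize_from_df() normalizes it with the standard tracking columns.
    df = pd.DataFrame({"year": [2020, 2021, 2022]})

    job_db = STACAPIJobDatabase(
        collection_id="my-openeo-jobs",
        stac_root_url="https://stac.example.org",
        auth=requests.auth.HTTPBasicAuth("user", "password"),
    ).initialize_from_df(df, on_exists="skip")

    manager = MultiBackendJobManager()
    manager.add_backend("example", connection=openeo.connect("https://openeo.example"))
    manager.run_jobs(start_job=start_job, job_db=job_db)

Note that ``on_exists="skip"`` makes re-runs reuse an existing collection instead of raising, matching the semantics documented in ``initialize_from_df``.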

setup.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
     "dirty_equals>=0.8.0",
     "pyarrow>=10.0.1",  # For Parquet read/write support in pandas
     "python-dateutil>=2.7.0",
+    "pystac-client>=0.7.5",
 ]

 docs_require = [
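
A quick, hypothetical sanity check that an environment satisfies the new dependency pin (not part of the commit):

    import pystac_client

    # setup.py now requires pystac-client>=0.7.5 for the STAC API job database.
    print(pystac_client.__version__)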
