Skip to content

Commit 7ae446e

Browse files
Fix several bugs in STACAPIJobDatabase
1 parent 3ffcf3e commit 7ae446e

File tree

1 file changed

+83
-63
lines changed

1 file changed

+83
-63
lines changed

openeo/extra/stac_job_db.py

Lines changed: 83 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import concurrent.futures
import logging
from datetime import datetime
from typing import Iterable, List, Union

import geopandas as gpd
import pandas as pd
import pystac
import requests
from pystac import Collection, Item
from pystac_client import Client
from requests.auth import HTTPBasicAuth
from shapely.geometry import mapping, shape

from openeo.extra.job_management import JobDatabaseInterface, MultiBackendJobManager
1616

1717
_log = logging.getLogger(__name__)
1818

19+
1920
class STACAPIJobDatabase(JobDatabaseInterface):
2021
"""
2122
Persist/load job metadata from a STAC API
@@ -25,7 +26,9 @@ class STACAPIJobDatabase(JobDatabaseInterface):
2526
:implements: :py:class:`JobDatabaseInterface`
2627
"""
2728

28-
def __init__(self, collection_id: str, stac_root_url: str, auth: requests.auth.AuthBase):
29+
def __init__(
30+
self, collection_id: str, stac_root_url: str, auth: requests.auth.AuthBase, has_geometry: bool = False
31+
):
2932
"""
3033
Initialize the STACAPIJobDatabase.
3134
@@ -37,17 +40,51 @@ def __init__(self, collection_id: str, stac_root_url: str, auth: requests.auth.A
3740
self.client = Client.open(stac_root_url)
3841

3942
self._auth = auth
43+
self.has_geometry = has_geometry
44+
self.geometry_column = "geometry"
4045
self.base_url = stac_root_url
4146
self.bulk_size = 500
42-
#self.collection = self.client.get_collection(collection_id)
4347

4448

4549

4650
def exists(self) -> bool:
47-
return len([c.id for c in self.client.get_collections() if c.id == self.collection_id ]) >0
51+
return len([c.id for c in self.client.get_collections() if c.id == self.collection_id]) > 0
52+
53+
def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"):
    """
    Initialize the job database from a given dataframe,
    which will be first normalized to be compatible
    with :py:class:`MultiBackendJobManager` usage.

    :param df: dataframe with some columns your ``start_job`` callable expects
    :param on_exists: what to do when the job database already exists (persisted on disk):
        - "error": (default) raise an exception
        - "skip": work with existing database, ignore given dataframe and skip any initialization

    :return: initialized job database.
    """
    if self.exists():
        if on_exists == "skip":
            return self
        elif on_exists == "error":
            raise FileExistsError(f"Job database {self!r} already exists.")
        else:
            # TODO handle other on_exists modes: e.g. overwrite, merge, ...
            raise ValueError(f"Invalid on_exists={on_exists!r}")

    if isinstance(df, gpd.GeoDataFrame):
        _log.warning("Job Database is initialized from GeoDataFrame. Converting geometries to GeoJSON.")
        geometry_name = df.geometry.name
        self.geometry_column = geometry_name
        # Bug fix: convert the *actual* geometry column. The previous hard-coded
        # "geometry" key broke GeoDataFrames whose geometry column has another name
        # (the name was recorded in self.geometry_column but never used here).
        df[geometry_name] = df[geometry_name].apply(mapping)
        df = pd.DataFrame(df)
        self.has_geometry = True

    df = MultiBackendJobManager._normalize_df(df)
    self.persist(df)
    # Return self to allow chaining with constructor.
    return self
86+
87+
def series_from(self, item: pystac.Item) -> pd.Series:
    """
    Convert a STAC Item to a pandas.Series.

    The returned Series holds the item's ``properties`` and is named after the item id.
    """
    as_dict = item.to_dict()
    # Parse the datetime string into a datetime object on the top-level dict.
    # NOTE(review): only the "properties" sub-dict ends up in the returned Series,
    # so this parsed value is not part of the result — confirm this is intended.
    raw_dt = as_dict["properties"]["datetime"]
    as_dict["datetime"] = pystac.utils.str_to_datetime(raw_dt)
    return pd.Series(as_dict["properties"], name=as_dict["id"])
70100

71-
@staticmethod
72-
def item_from(series: pd.Series, geometry_name="geometry"):
101+
def item_from(self, series: pd.Series, geometry_name: str = "geometry") -> pystac.Item:
73102
"""
74103
Convert a pandas.Series to a STAC Item.
75104
@@ -93,64 +122,70 @@ def item_from(series: pd.Series, geometry_name="geometry"):
93122
else:
94123
item_dict["properties"]["datetime"] = pystac.utils.datetime_to_str(datetime.now())
95124

96-
item_dict["geometry"] = mapping(series[geometry_name])
97-
del series_dict[geometry_name]
125+
if self.has_geometry:
126+
item_dict["geometry"] = series[geometry_name]
127+
else:
128+
item_dict["geometry"] = None
98129

99130
# from_dict handles associating any Links and Assets with the Item
100-
item_dict['id'] = series.name
131+
item_dict["id"] = series.name
101132
item = pystac.Item.from_dict(item_dict)
102-
item.bbox = series[geometry_name].bounds
133+
if self.has_geometry:
134+
item.bbox = shape(series[geometry_name]).bounds
135+
else:
136+
item.bbox = None
103137
return item
104138

105-
def count_by_status(self, statuses: List[str]) -> dict:
106-
#todo: replace with use of stac aggregation extension
107-
#example of how what an aggregation call looks like: https://stac-openeo-dev.vgt.vito.be/collections/copernicus_r_utm-wgs84_10_m_hrvpp-vpp_p_2017-now_v01/aggregate?aggregations=total_count&filter=description%3DSOSD&filter-lang=cql2-text
139+
def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
108140
items = self.get_by_status(statuses,max=200)
109141
if items is None:
110-
return { k:0 for k in statuses}
142+
return {k: 0 for k in statuses}
111143
else:
112144
return items["status"].value_counts().to_dict()
113145

114-
def get_by_status(self, statuses: Iterable[str], max=None) -> pd.DataFrame:
    """Fetch job items whose status matches any of the given statuses.

    :param statuses: statuses to filter on; empty means no status filter.
    :param max: maximum number of items to retrieve (None for no limit).
    :return: dataframe of job properties, one row per item (item id as index).
    """
    statuses = {statuses} if isinstance(statuses, str) else set(statuses)

    # Build a CQL2-text OR-filter over the requested statuses (None disables filtering).
    if statuses:
        status_filter = " OR ".join(f"\"properties.status\"='{s}'" for s in statuses)
    else:
        status_filter = None

    search_results = self.client.search(
        method="GET",
        collections=[self.collection_id],
        filter=status_filter,
        max_items=max,
    )

    rows = [self.series_from(item) for item in search_results.items()]
    df = pd.DataFrame(rows)
    if not rows:
        # Even an empty dataframe must expose the default job-management columns.
        # TODO: What if default columns are overwritten by the user?
        df = MultiBackendJobManager._normalize_df(df)
    return df
137169

138170

139171

140172
def persist(self, df: pd.DataFrame):
141-
142173
if not self.exists():
143-
c= pystac.Collection(id=self.collection_id,description="test collection for jobs",extent=pystac.Extent(spatial=pystac.SpatialExtent(bboxes=[list(df.total_bounds)]),temporal=pystac.TemporalExtent(intervals=[None,None])))
174+
spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
175+
temporal_extent = pystac.TemporalExtent([[None, None]])
176+
extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
177+
c = pystac.Collection(id=self.collection_id, description="STAC API job database collection.", extent=extent)
144178
self._create_collection(c)
145179

146180
all_items = []
147-
def handle_row(series):
148-
item = STACAPIJobDatabase.item_from(series,df.geometry.name)
149-
#upload item
150-
all_items.append(item)
181+
if not df.empty:
182+
183+
def handle_row(series):
184+
item = self.item_from(series, self.geometry_column)
185+
all_items.append(item)
151186

152187

153-
df.apply(handle_row, axis=1)
188+
df.apply(handle_row, axis=1)
154189

155190
self._upload_items_bulk(self.collection_id, all_items)
156191

@@ -176,27 +211,19 @@ def _ingest_bulk(self, items: Iterable[Item]) -> dict:
176211

177212
def _upload_items_bulk(self, collection_id: str, items: Iterable[Item]) -> None:
178213
chunk = []
179-
chunk_start = 0
180-
chunk_end = 0
181214
futures = []
182215

183216
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
184-
for index, item in enumerate(items):
217+
for item in items:
185218
self._prepare_item(item, collection_id)
186219
# item.validate()
187220
chunk.append(item)
188221

189222
if len(chunk) == self.bulk_size:
190-
chunk_end = index + 1
191-
chunk_start = chunk_end - len(chunk) + 1
192-
193223
futures.append(executor.submit(self._ingest_bulk, chunk.copy()))
194224
chunk = []
195225

196226
if chunk:
197-
chunk_end = index + 1
198-
chunk_start = chunk_end - len(chunk) + 1
199-
200227
self._ingest_bulk(chunk)
201228

202229
for _ in concurrent.futures.as_completed(futures):
@@ -241,18 +268,11 @@ def _create_collection(self, collection: Collection) -> dict:
241268
return response.json()
242269

243270

# HTTP status codes accepted as success for POST requests (item/collection creation).
_EXPECTED_STATUS_POST = [
    requests.status_codes.codes.ok,
    requests.status_codes.codes.created,
    requests.status_codes.codes.accepted,
]
256276

257277
def _check_response_status(response: requests.Response, expected_status_codes: list[int], raise_exc: bool = False):
258278
if response.status_code not in expected_status_codes:
@@ -267,4 +287,4 @@ def _check_response_status(response: requests.Response, expected_status_codes: l
267287
_log.warning(message)
268288

269289
# Always raise errors on 4xx and 5xx status codes.
270-
response.raise_for_status()
290+
response.raise_for_status()

0 commit comments

Comments
 (0)