|
| 1 | +import concurrent.futures |
| 2 | +import datetime |
| 3 | +import logging |
| 4 | +from typing import Iterable, List |
| 5 | + |
| 6 | +import geopandas as gpd |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import pystac |
| 10 | +import pystac_client |
| 11 | +import requests |
| 12 | +from shapely.geometry import mapping, shape |
| 13 | + |
| 14 | +from openeo.extra.job_management import JobDatabaseInterface, MultiBackendJobManager |
| 15 | + |
| 16 | +_log = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
class STACAPIJobDatabase(JobDatabaseInterface):
    """
    Persist/load job metadata from a STAC API

    Unstable API, subject to change.

    :implements: :py:class:`JobDatabaseInterface`
    """

    def __init__(
        self,
        collection_id: str,
        stac_root_url: str,
        auth: requests.auth.AuthBase,
        has_geometry: bool = False,
        geometry_column: str = "geometry",
    ):
        """
        Initialize the STACAPIJobDatabase.

        :param collection_id: The ID of the STAC collection.
        :param stac_root_url: The root URL of the STAC API.
        :param auth: requests AuthBase that will be used to authenticate, e.g. OAuth2ResourceOwnerPasswordCredentials
        :param has_geometry: Whether the job metadata supports any geometry that implements __geo_interface__.
        :param geometry_column: The name of the geometry column in the job metadata that implements __geo_interface__.
        """
        self.collection_id = collection_id
        self.client = pystac_client.Client.open(stac_root_url)

        self._auth = auth
        self.has_geometry = has_geometry
        self.geometry_column = geometry_column
        self.base_url = stac_root_url
        # Number of items sent per bulk upsert request.
        self.bulk_size = 500

    def exists(self) -> bool:
        """Check whether the backing STAC collection already exists on the API."""
        return any(c.id == self.collection_id for c in self.client.get_collections())

    def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"):
        """
        Initialize the job database from a given dataframe,
        which will be first normalized to be compatible
        with :py:class:`MultiBackendJobManager` usage.

        :param df: dataframe with some columns your ``start_job`` callable expects
        :param on_exists: what to do when the job database already exists (persisted on disk):
            - "error": (default) raise an exception
            - "skip": work with existing database, ignore given dataframe and skip any initialization
            - "append": add given dataframe to existing database

        :return: initialized job database.
        """
        if isinstance(df, gpd.GeoDataFrame):
            # STAC items carry geometry as GeoJSON, so convert shapely geometries up front.
            df = df.copy()
            _log.warning("Job Database is initialized from GeoDataFrame. Converting geometries to GeoJSON.")
            self.geometry_column = df.geometry.name
            df[self.geometry_column] = df[self.geometry_column].apply(lambda x: mapping(x))
            df = pd.DataFrame(df)
            self.has_geometry = True

        if self.exists():
            if on_exists == "skip":
                return self
            elif on_exists == "error":
                raise FileExistsError(f"Job database {self!r} already exists.")
            elif on_exists == "append":
                # Empty status filter: fetch all existing items.
                existing_df = self.get_by_status([])
                df = MultiBackendJobManager._normalize_df(df)
                df = pd.concat([existing_df, df], ignore_index=True).replace({np.nan: None})
                self.persist(df)
                return self
            else:
                raise ValueError(f"Invalid on_exists={on_exists!r}")

        df = MultiBackendJobManager._normalize_df(df)
        self.persist(df)
        # Return self to allow chaining with constructor.
        return self

    def series_from(self, item: pystac.Item) -> pd.Series:
        """
        Convert a STAC Item to a pandas.Series.

        The item's properties become the series values and the item id becomes
        the series name (so it round-trips through :py:meth:`item_from`).

        :param item: STAC Item to be converted.
        :return: pandas.Series
        """
        item_dict = item.to_dict()
        return pd.Series(item_dict["properties"], name=item_dict["id"])

    def item_from(self, series: pd.Series) -> pystac.Item:
        """
        Convert a pandas.Series to a STAC Item.

        The series values become the item properties and the series name
        becomes the item id.

        :param series: pandas.Series to be converted.
        :return: pystac.Item
        """
        series_dict = series.to_dict()
        item_dict = {
            "stac_version": pystac.get_stac_version(),
            "type": "Feature",
            "assets": {},
            "links": [],
            "properties": series_dict,
        }

        # STAC requires a "datetime" property: keep the one provided through the
        # series (converting datetime objects to ISO strings), otherwise default
        # to "now" (UTC). Note: the original check compared series_dict against
        # itself and always fell through to "now", clobbering existing datetimes.
        dt = series_dict.get("datetime", None)
        if dt:
            item_dict["properties"]["datetime"] = (
                pystac.utils.datetime_to_str(dt) if isinstance(dt, datetime.datetime) else dt
            )
        else:
            item_dict["properties"]["datetime"] = pystac.utils.datetime_to_str(
                datetime.datetime.now(datetime.timezone.utc)
            )

        if self.has_geometry:
            item_dict["geometry"] = series[self.geometry_column]
        else:
            item_dict["geometry"] = None

        # from_dict handles associating any Links and Assets with the Item
        item_dict["id"] = series.name
        item = pystac.Item.from_dict(item_dict)
        if self.has_geometry:
            # pystac expects a list of floats for bbox; shapely's .bounds is a tuple.
            item.bbox = list(shape(series[self.geometry_column]).bounds)
        else:
            item.bbox = None
        return item

    def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
        """
        Count the jobs per status.

        :param statuses: statuses to count; empty means all statuses.
        :return: dict mapping status to number of jobs with that status.
        """
        if isinstance(statuses, str):
            statuses = {statuses}
        statuses = set(statuses)
        # NOTE: counts are derived from at most 200 matching items.
        df = self.get_by_status(statuses, max=200)
        # get_by_status always returns a DataFrame; report explicit zeros when empty.
        if df.empty:
            return {k: 0 for k in statuses}
        return df["status"].value_counts().to_dict()

    def get_by_status(self, statuses: Iterable[str], max=None) -> pd.DataFrame:
        """
        Fetch jobs with any of the given statuses as a dataframe.

        :param statuses: statuses to filter on; empty means no status filter (all jobs).
        :param max: maximum number of items to fetch (``None`` for no limit).
        :return: dataframe with one row per matching STAC item.
        """
        if isinstance(statuses, str):
            statuses = {statuses}
        statuses = set(statuses)

        # CQL2-text filter, e.g.: "properties.status"='queued' OR "properties.status"='running'
        status_filter = " OR ".join([f"\"properties.status\"='{s}'" for s in statuses]) if statuses else None
        search_results = self.client.search(
            method="GET",
            collections=[self.collection_id],
            filter=status_filter,
            max_items=max,
        )

        series = [self.series_from(item) for item in search_results.items()]

        df = pd.DataFrame(series)
        if len(series) == 0:
            # TODO: What if default columns are overwritten by the user?
            # Even for an empty dataframe the default columns are required.
            df = MultiBackendJobManager._normalize_df(df)
        return df

    def persist(self, df: pd.DataFrame):
        """
        Upsert the given job metadata into the STAC collection,
        creating the collection first if it does not exist yet.

        :param df: dataframe with job metadata, one STAC item per row.
        """
        if not self.exists():
            # Collection covering the whole globe and an open-ended time range.
            spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
            temporal_extent = pystac.TemporalExtent([[None, None]])
            extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
            c = pystac.Collection(id=self.collection_id, description="STAC API job database collection.", extent=extent)
            self._create_collection(c)

        all_items = []
        if not df.empty:
            all_items = [self.item_from(row) for _, row in df.iterrows()]

        self._upload_items_bulk(self.collection_id, all_items)

    def _prepare_item(self, item: pystac.Item, collection_id: str):
        """Attach the item to the given collection (id and collection link)."""
        item.collection_id = collection_id

        if not item.get_links(pystac.RelType.COLLECTION):
            item.add_link(pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id))

    def _ingest_bulk(self, items: List[pystac.Item]) -> dict:
        """
        Upsert a single chunk of items through the STAC API bulk endpoint.

        :param items: items to upsert; all must belong to the same collection.
        :return: dict with the JSON body of the HTTP response.
        """
        collection_id = items[0].collection_id
        if not all(i.collection_id == collection_id for i in items):
            raise Exception("All collection IDs should be identical for bulk ingests")

        url_path = f"collections/{collection_id}/bulk_items"
        data = {"method": "upsert", "items": {item.id: item.to_dict() for item in items}}
        response = requests.post(url=self.join_url(url_path), auth=self._auth, json=data)

        _log.info(f"HTTP response: {response.status_code} - {response.reason}: body: {response.json()}")

        _check_response_status(response, _EXPECTED_STATUS_POST)
        return response.json()

    def _upload_items_bulk(self, collection_id: str, items: List[pystac.Item]) -> None:
        """
        Upload items in chunks of ``self.bulk_size``, parallelized over a small thread pool.

        :param collection_id: collection to attach the items to.
        :param items: items to upload.
        """
        chunk = []
        futures = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            for item in items:
                self._prepare_item(item, collection_id)
                chunk.append(item)

                if len(chunk) == self.bulk_size:
                    futures.append(executor.submit(self._ingest_bulk, chunk.copy()))
                    chunk = []

            # Remaining partial chunk is ingested synchronously.
            if chunk:
                self._ingest_bulk(chunk)

            for future in concurrent.futures.as_completed(futures):
                # Propagate any exception raised in the worker threads
                # (iterating as_completed without .result() would swallow them).
                future.result()

    def join_url(self, url_path: str) -> str:
        """Create a URL from the base_url and the url_path.

        :param url_path: same as in join_path
        :return: a URL object that represents the full URL.
        """
        return f"{self.base_url}/{url_path}"

    def _create_collection(self, collection: pystac.Collection) -> dict:
        """Create a new collection.

        :param collection: pystac.Collection object to create in the STAC API backend (or upload if you will)
        :raises TypeError: if collection is not a pystac.Collection.
        :return: dict that contains the JSON body of the HTTP response.
        """

        if not isinstance(collection, pystac.Collection):
            raise TypeError(
                f'Argument "collection" must be of type pystac.Collection, but its type is {type(collection)=}'
            )

        collection.validate()
        coll_dict = collection.to_dict()

        # Backend-specific auth extension: anonymous read, admin/editor write.
        default_auth = {
            "_auth": {
                "read": ["anonymous"],
                "write": ["stac-openeo-admin", "stac-openeo-editor"],
            }
        }

        coll_dict.update(default_auth)

        response = requests.post(self.join_url("collections"), auth=self._auth, json=coll_dict)
        _check_response_status(response, _EXPECTED_STATUS_POST)

        return response.json()
| 278 | + |
| 279 | + |
# HTTP status codes that count as success for POST requests
# (collection creation and bulk item ingestion): 200, 201, 202.
_EXPECTED_STATUS_POST = [
    requests.status_codes.codes.ok,
    requests.status_codes.codes.created,
    requests.status_codes.codes.accepted,
]
| 285 | + |
| 286 | + |
def _check_response_status(response: requests.Response, expected_status_codes: List[int], raise_exc: bool = False):
    """
    Verify that an HTTP response has one of the expected status codes.

    :param response: the HTTP response to inspect.
    :param expected_status_codes: status codes considered successful.
    :param raise_exc: when True, raise on an unexpected status code;
        when False, only log a warning for it.
    :raises Exception: if the status code is unexpected and ``raise_exc`` is set.
    :raises requests.HTTPError: for any 4xx/5xx status code, regardless of ``raise_exc``.
    """
    if response.status_code not in expected_status_codes:
        message = (
            f"Expecting HTTP status to be any of {expected_status_codes} "
            f"but received {response.status_code} - {response.reason}, request method={response.request.method}\n"
            f"response body:\n{response.text}"
        )
        if raise_exc:
            raise Exception(message)
        _log.warning(message)

    # Always raise errors on 4xx and 5xx status codes.
    response.raise_for_status()
0 commit comments