11import concurrent
22import logging
33from datetime import datetime
4- from typing import List , Union , Iterable
4+ from typing import Iterable , List , Union
55
6- import pandas as pd
76import geopandas as gpd
7+ import pandas as pd
88import pystac
99import requests
1010from pystac import Collection , Item
1111from pystac_client import Client
1212from requests .auth import HTTPBasicAuth
13- from shapely .geometry import shape , mapping
13+ from shapely .geometry import mapping , shape
1414
15- from openeo .extra .job_management import JobDatabaseInterface
15+ from openeo .extra .job_management import JobDatabaseInterface , MultiBackendJobManager
1616
1717_log = logging .getLogger (__name__ )
1818
19+
1920class STACAPIJobDatabase (JobDatabaseInterface ):
2021 """
2122 Persist/load job metadata from a STAC API
@@ -25,7 +26,9 @@ class STACAPIJobDatabase(JobDatabaseInterface):
2526 :implements: :py:class:`JobDatabaseInterface`
2627 """
2728
28- def __init__ (self , collection_id : str , stac_root_url : str , auth : requests .auth .AuthBase ):
29+ def __init__ (
30+ self , collection_id : str , stac_root_url : str , auth : requests .auth .AuthBase , has_geometry : bool = False
31+ ):
2932 """
3033 Initialize the STACAPIJobDatabase.
3134
@@ -37,17 +40,51 @@ def __init__(self, collection_id: str, stac_root_url: str, auth: requests.auth.A
3740 self .client = Client .open (stac_root_url )
3841
3942 self ._auth = auth
43+ self .has_geometry = has_geometry
44+ self .geometry_column = "geometry"
4045 self .base_url = stac_root_url
4146 self .bulk_size = 500
42- #self.collection = self.client.get_collection(collection_id)
4347
4448
4549
4650 def exists (self ) -> bool :
47- return len ([c .id for c in self .client .get_collections () if c .id == self .collection_id ]) > 0
51+ return len ([c .id for c in self .client .get_collections () if c .id == self .collection_id ]) > 0
52+
53+ def initialize_from_df (self , df : pd .DataFrame , * , on_exists : str = "error" ):
54+ """
55+ Initialize the job database from a given dataframe,
56+ which will be first normalized to be compatible
57+ with :py:class:`MultiBackendJobManager` usage.
4858
49- @staticmethod
50- def series_from (item ):
59+ :param df: dataframe with some columns your ``start_job`` callable expects
60+ :param on_exists: what to do when the job database already exists (persisted on disk):
61+ - "error": (default) raise an exception
62+ - "skip": work with existing database, ignore given dataframe and skip any initialization
63+
64+ :return: initialized job database.
65+ """
66+ if self .exists ():
67+ if on_exists == "skip" :
68+ return self
69+ elif on_exists == "error" :
70+ raise FileExistsError (f"Job database { self !r} already exists." )
71+ else :
72+ # TODO handle other on_exists modes: e.g. overwrite, merge, ...
73+ raise ValueError (f"Invalid on_exists={ on_exists !r} " )
74+
75+ if isinstance (df , gpd .GeoDataFrame ):
76+ _log .warning ("Job Database is initialized from GeoDataFrame. Converting geometries to GeoJSON." )
77+ self .geometry_column = df .geometry .name
78+ df ["geometry" ] = df ["geometry" ].apply (lambda x : mapping (x ))
79+ df = pd .DataFrame (df )
80+ self .has_geometry = True
81+
82+ df = MultiBackendJobManager ._normalize_df (df )
83+ self .persist (df )
84+ # Return self to allow chaining with constructor.
85+ return self
86+
87+ def series_from (self , item : pystac .Item ) -> pd .Series :
5188 """
5289 Convert a STAC Item to a pandas.Series.
5390
@@ -56,20 +93,12 @@ def series_from(item):
5693 """
5794 item_dict = item .to_dict ()
5895 item_id = item_dict ["id" ]
59- print (item_dict )
60- # Promote datetime
6196 dt = item_dict ["properties" ]["datetime" ]
6297 item_dict ["datetime" ] = pystac .utils .str_to_datetime (dt )
63- #del item_dict["properties"]["datetime"]
64-
6598
66- # Convert geojson geom into shapely.Geometry
67- item_dict ["properties" ]["geometry" ] = shape (item_dict ["geometry" ])
68- #item_dict["properties"]["name"] = item_id
6999 return pd .Series (item_dict ["properties" ], name = item_id )
70100
71- @staticmethod
72- def item_from (series : pd .Series , geometry_name = "geometry" ):
101+ def item_from (self , series : pd .Series , geometry_name : str = "geometry" ) -> pystac .Item :
73102 """
74103 Convert a pandas.Series to a STAC Item.
75104
@@ -93,64 +122,70 @@ def item_from(series: pd.Series, geometry_name="geometry"):
93122 else :
94123 item_dict ["properties" ]["datetime" ] = pystac .utils .datetime_to_str (datetime .now ())
95124
96- item_dict ["geometry" ] = mapping (series [geometry_name ])
97- del series_dict [geometry_name ]
125+ if self .has_geometry :
126+ item_dict ["geometry" ] = series [geometry_name ]
127+ else :
128+ item_dict ["geometry" ] = None
98129
99130 # from_dict handles associating any Links and Assets with the Item
100- item_dict ['id' ] = series .name
131+ item_dict ["id" ] = series .name
101132 item = pystac .Item .from_dict (item_dict )
102- item .bbox = series [geometry_name ].bounds
133+ if self .has_geometry :
134+ item .bbox = shape (series [geometry_name ]).bounds
135+ else :
136+ item .bbox = None
103137 return item
104138
105- def count_by_status (self , statuses : List [str ]) -> dict :
106- #todo: replace with use of stac aggregation extension
107- #example of how what an aggregation call looks like: https://stac-openeo-dev.vgt.vito.be/collections/copernicus_r_utm-wgs84_10_m_hrvpp-vpp_p_2017-now_v01/aggregate?aggregations=total_count&filter=description%3DSOSD&filter-lang=cql2-text
139+ def count_by_status (self , statuses : Iterable [str ] = ()) -> dict :
108140 items = self .get_by_status (statuses ,max = 200 )
109141 if items is None :
110- return { k : 0 for k in statuses }
142+ return {k : 0 for k in statuses }
111143 else :
112144 return items ["status" ].value_counts ().to_dict ()
113145
114- def get_by_status (self , statuses : List [str ], max = None ) -> pd .DataFrame :
146+ def get_by_status (self , statuses : Iterable [str ], max = None ) -> pd .DataFrame :
115147
116- if isinstance (statuses ,str ):
117- statuses = [statuses ]
148+ if isinstance (statuses , str ):
149+ statuses = {statuses }
150+ statuses = set (statuses )
118151
119- status_filter = " OR " .join ([ f"\" properties.status\" ={ s } " for s in statuses ])
152+ status_filter = " OR " .join ([f"\" properties.status\" =' { s } ' " for s in statuses ]) if statuses else None
120153 search_results = self .client .search (
121154 method = "GET" ,
122155 collections = [self .collection_id ],
123156 filter = status_filter ,
124157 max_items = max ,
125- fields = ["properties" ]
126158 )
127159
128- crs = "EPSG:4326"
129- series = [STACAPIJobDatabase .series_from (item ) for item in search_results .items ()]
160+ series = [self .series_from (item ) for item in search_results .items ()]
161+
162+ df = pd .DataFrame (series )
130163 if len (series ) == 0 :
131- return None
132- gdf = gpd .GeoDataFrame (series , crs = crs )
133- # TODO how to know the proper name of the geometry column?
134- # this only matters for the udp based version probably
135- #gdf.rename_geometry("polygon", inplace=True)
136- return gdf
164+ # TODO: What if default columns are overwritten by the user?
165+ df = MultiBackendJobManager ._normalize_df (
166+ df
167+ ) # Even for an empty dataframe the default columns are required
168+ return df
137169
138170
139171
140172 def persist (self , df : pd .DataFrame ):
141-
142173 if not self .exists ():
143- c = pystac .Collection (id = self .collection_id ,description = "test collection for jobs" ,extent = pystac .Extent (spatial = pystac .SpatialExtent (bboxes = [list (df .total_bounds )]),temporal = pystac .TemporalExtent (intervals = [None ,None ])))
174+ spatial_extent = pystac .SpatialExtent ([[- 180 , - 90 , 180 , 90 ]])
175+ temporal_extent = pystac .TemporalExtent ([[None , None ]])
176+ extent = pystac .Extent (spatial = spatial_extent , temporal = temporal_extent )
177+ c = pystac .Collection (id = self .collection_id , description = "STAC API job database collection." , extent = extent )
144178 self ._create_collection (c )
145179
146180 all_items = []
147- def handle_row (series ):
148- item = STACAPIJobDatabase .item_from (series ,df .geometry .name )
149- #upload item
150- all_items .append (item )
181+ if not df .empty :
182+
183+ def handle_row (series ):
184+ item = self .item_from (series , self .geometry_column )
185+ all_items .append (item )
151186
152187
153- df .apply (handle_row , axis = 1 )
188+ df .apply (handle_row , axis = 1 )
154189
155190 self ._upload_items_bulk (self .collection_id , all_items )
156191
@@ -176,27 +211,19 @@ def _ingest_bulk(self, items: Iterable[Item]) -> dict:
176211
177212 def _upload_items_bulk (self , collection_id : str , items : Iterable [Item ]) -> None :
178213 chunk = []
179- chunk_start = 0
180- chunk_end = 0
181214 futures = []
182215
183216 with concurrent .futures .ThreadPoolExecutor (max_workers = 4 ) as executor :
184- for index , item in enumerate ( items ) :
217+ for item in items :
185218 self ._prepare_item (item , collection_id )
186219 # item.validate()
187220 chunk .append (item )
188221
189222 if len (chunk ) == self .bulk_size :
190- chunk_end = index + 1
191- chunk_start = chunk_end - len (chunk ) + 1
192-
193223 futures .append (executor .submit (self ._ingest_bulk , chunk .copy ()))
194224 chunk = []
195225
196226 if chunk :
197- chunk_end = index + 1
198- chunk_start = chunk_end - len (chunk ) + 1
199-
200227 self ._ingest_bulk (chunk )
201228
202229 for _ in concurrent .futures .as_completed (futures ):
@@ -241,18 +268,11 @@ def _create_collection(self, collection: Collection) -> dict:
241268 return response .json ()
242269
243270
# HTTP status codes accepted as success when POSTing items/collections
# (200 OK, 201 Created, 202 Accepted). The unused _EXPECTED_STATUS_GET and
# _EXPECTED_STATUS_PUT lists were removed.
_EXPECTED_STATUS_POST = [
    requests.status_codes.codes.ok,
    requests.status_codes.codes.created,
    requests.status_codes.codes.accepted,
]
256276
257277def _check_response_status (response : requests .Response , expected_status_codes : list [int ], raise_exc : bool = False ):
258278 if response .status_code not in expected_status_codes :
@@ -267,4 +287,4 @@ def _check_response_status(response: requests.Response, expected_status_codes: l
267287 _log .warning (message )
268288
269289 # Always raise errors on 4xx and 5xx status codes.
270- response .raise_for_status ()
290+ response .raise_for_status ()
0 commit comments