Skip to content

Commit 4b8593c

Browse files
committed
Issue #571 improve reconstruction of geometries in CsvJobDatabase/ParquetJobDatabase
1 parent b272791 commit 4b8593c

File tree

2 files changed

+65
-22
lines changed

2 files changed

+65
-22
lines changed

openeo/extra/job_management.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from typing import Callable, Dict, NamedTuple, Optional, Union
1010

11+
import geopandas
1112
import pandas as pd
1213
import requests
1314
import shapely.errors
@@ -556,13 +557,13 @@ def _is_valid_wkt(self, wkt: str) -> bool:
556557

557558
def read(self) -> pd.DataFrame:
558559
df = pd.read_csv(self.path)
559-
# `df.to_csv` in `persist()` will encode geometries as WKT, so we decode that here.
560560
if (
561561
"geometry" in df.columns
562562
and df["geometry"].dtype.name != "geometry"
563563
and self._is_valid_wkt(df["geometry"].iloc[0])
564564
):
565-
df["geometry"] = df["geometry"].apply(shapely.wkt.loads)
565+
# `df.to_csv()` in `persist()` has encoded geometries as WKT, so we decode that here.
566+
df = geopandas.GeoDataFrame(df, geometry=geopandas.GeoSeries.from_wkt(df["geometry"]))
566567
return df
567568

568569
def persist(self, df: pd.DataFrame):
@@ -590,7 +591,19 @@ def exists(self) -> bool:
590591
return self.path.exists()
591592

592593
def read(self) -> pd.DataFrame:
593-
return pd.read_parquet(self.path)
594+
# Unfortunately, a naive `pandas.read_parquet()` does not easily allow
595+
# reconstructing geometries from a GeoPandas Parquet file.
596+
# And vice-versa, `geopandas.read_parquet()` does not support reading
597+
# Parquet files without geometries.
598+
# So we have to guess which case we have.
599+
# TODO is there a cleaner way to do this?
600+
import pyarrow.parquet
601+
602+
metadata = pyarrow.parquet.read_metadata(self.path)
603+
if b"geo" in metadata.metadata:
604+
return geopandas.read_parquet(self.path)
605+
else:
606+
return pd.read_parquet(self.path)
594607

595608
def persist(self, df: pd.DataFrame):
596609
self.path.parent.mkdir(parents=True, exist_ok=True)

tests/extra/test_job_management.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import json
2+
import textwrap
23
import threading
34
from unittest import mock
45

6+
import geopandas
7+
58
# TODO: can we avoid using httpretty?
69
# We need it for testing the resilience, which uses an HTTPAdapter with Retry
710
# but requests-mock also uses an HTTPAdapter for the mocking and basically
811
# erases the HTTPAdapter we have set up.
912
# httpretty avoids this specific problem because it mocks at the socket level,
1013
# but I would rather not have two dependencies with almost the same goal.
1114
import httpretty
15+
import pandas
1216
import pandas as pd
1317
import pytest
1418
import requests
15-
import shapely.geometry.point as shpt
19+
import shapely.geometry
1620

1721
import openeo
1822
from openeo import BatchJob
@@ -456,6 +460,26 @@ def start_job(row, connection_provider, connection, **kwargs):
456460
assert set(result.backend_name) == {"foo"}
457461

458462

463+
JOB_DB_DF_BASICS = pd.DataFrame(
464+
{
465+
"numbers": [3, 2, 1],
466+
"names": ["apple", "banana", "coconut"],
467+
}
468+
)
469+
JOB_DB_GDF_WITH_GEOMETRY = geopandas.GeoDataFrame(
470+
{
471+
"numbers": [11, 22],
472+
"geometry": [shapely.geometry.Point(1, 2), shapely.geometry.Point(2, 1)],
473+
},
474+
)
475+
JOB_DB_DF_WITH_GEOJSON_STRING = pd.DataFrame(
476+
{
477+
"numbers": [11, 22],
478+
"geometry": ['{"type":"Point","coordinates":[1,2]}', '{"type":"Point","coordinates":[1,2]}'],
479+
}
480+
)
481+
482+
459483
class TestCsvJobDatabase:
460484
def test_read_wkt(self, tmp_path):
461485
wkt_df = pd.DataFrame(
@@ -467,7 +491,7 @@ def test_read_wkt(self, tmp_path):
467491
path = tmp_path / "jobs.csv"
468492
wkt_df.to_csv(path, index=False)
469493
df = CsvJobDatabase(path).read()
470-
assert isinstance(df.geometry[0], shpt.Point)
494+
assert isinstance(df.geometry[0], shapely.geometry.Point)
471495

472496
def test_read_non_wkt(self, tmp_path):
473497
non_wkt_df = pd.DataFrame(
@@ -481,34 +505,40 @@ def test_read_non_wkt(self, tmp_path):
481505
df = CsvJobDatabase(path).read()
482506
assert isinstance(df.geometry[0], str)
483507

484-
def test_persist_and_read(self, tmp_path):
485-
orig = pd.DataFrame(
486-
{
487-
"numbers": [3, 2, 1],
488-
"names": ["apple", "banana", "coconut"],
489-
}
490-
)
491-
path = tmp_path / "jobs.csv"
508+
@pytest.mark.parametrize(
509+
["orig"],
510+
[
511+
pytest.param(JOB_DB_DF_BASICS, id="pandas basics"),
512+
pytest.param(JOB_DB_GDF_WITH_GEOMETRY, id="geopandas with geometry"),
513+
pytest.param(JOB_DB_DF_WITH_GEOJSON_STRING, id="pandas with geojson string as geometry"),
514+
],
515+
)
516+
def test_persist_and_read(self, tmp_path, orig: pandas.DataFrame):
517+
path = tmp_path / "jobs.parquet"
492518
CsvJobDatabase(path).persist(orig)
493519
assert path.exists()
494520

495521
loaded = CsvJobDatabase(path).read()
496-
assert list(loaded.dtypes) == list(orig.dtypes)
522+
assert loaded.dtypes.to_dict() == orig.dtypes.to_dict()
497523
assert loaded.equals(orig)
524+
assert type(orig) is type(loaded)
498525

499526

500527
class TestParquetJobDatabase:
501-
def test_persist_and_read(self, tmp_path):
502-
orig = pd.DataFrame(
503-
{
504-
"numbers": [3, 2, 1],
505-
"names": ["apple", "banana", "coconut"],
506-
}
507-
)
528+
@pytest.mark.parametrize(
529+
["orig"],
530+
[
531+
pytest.param(JOB_DB_DF_BASICS, id="pandas basics"),
532+
pytest.param(JOB_DB_GDF_WITH_GEOMETRY, id="geopandas with geometry"),
533+
pytest.param(JOB_DB_DF_WITH_GEOJSON_STRING, id="pandas with geojson string as geometry"),
534+
],
535+
)
536+
def test_persist_and_read(self, tmp_path, orig: pandas.DataFrame):
508537
path = tmp_path / "jobs.parquet"
509538
ParquetJobDatabase(path).persist(orig)
510539
assert path.exists()
511540

512541
loaded = ParquetJobDatabase(path).read()
513-
assert list(loaded.dtypes) == list(orig.dtypes)
542+
assert loaded.dtypes.to_dict() == orig.dtypes.to_dict()
514543
assert loaded.equals(orig)
544+
assert type(orig) is type(loaded)

0 commit comments

Comments
 (0)