Skip to content

Commit fcbddf2

Browse files
Register arrow import/export dispatch to make p2p shuffle work (#295)
1 parent 7679bf9 commit fcbddf2

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

dask_geopandas/backends.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22
from packaging.version import Version
33

4+
import dask
45
from dask import config
56

67
# Check if dask-dataframe is using dask-expr (default of None means True as well)
@@ -84,3 +85,36 @@ def get_pyarrow_schema_geopandas(obj):
8485
for col in obj.columns[obj.dtypes == "geometry"]:
8586
df[col] = obj[col].to_wkb()
8687
return pa.Schema.from_pandas(df)
88+
89+
90+
if Version(dask.__version__) >= Version("2023.6.1"):
91+
from dask.dataframe.dispatch import (
92+
from_pyarrow_table_dispatch,
93+
to_pyarrow_table_dispatch,
94+
)
95+
96+
@to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,))
97+
def get_pyarrow_table_from_geopandas(obj, **kwargs):
98+
# `kwargs` must be supported by `pyarrow.Table.from_pandas`
99+
import pyarrow as pa
100+
101+
if Version(geopandas.__version__).major < 1:
102+
return pa.Table.from_pandas(obj.to_wkb(), **kwargs)
103+
else:
104+
# TODO handle kwargs?
105+
return pa.table(obj.to_arrow())
106+
107+
@from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,))
108+
def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs):
109+
# `kwargs` must be supported by `pyarrow.Table.to_pandas`
110+
if Version(geopandas.__version__).major < 1:
111+
df = table.to_pandas(**kwargs)
112+
113+
for col in meta.columns[meta.dtypes == "geometry"]:
114+
df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs)
115+
116+
return df
117+
118+
else:
119+
# TODO handle kwargs?
120+
return geopandas.GeoDataFrame.from_arrow(table)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from packaging.version import Version
2+
3+
import geopandas
4+
5+
import dask_geopandas
6+
7+
import pytest
8+
from geopandas.testing import assert_geodataframe_equal
9+
10+
distributed = pytest.importorskip("distributed")
11+
12+
13+
from distributed import Client, LocalCluster
14+
15+
16+
@pytest.mark.skipif(
17+
Version(distributed.__version__) < Version("2024.6.0"),
18+
reason="distributed < 2024.6 has a wrong assertion",
19+
# https://github.com/dask/distributed/pull/8667
20+
)
21+
@pytest.mark.skipif(
22+
Version(distributed.__version__) < Version("0.13"),
23+
reason="geopandas < 0.13 does not implement sorting geometries",
24+
)
25+
def test_spatial_shuffle(naturalearth_cities):
26+
df_points = geopandas.read_file(naturalearth_cities)
27+
28+
with LocalCluster(n_workers=1) as cluster:
29+
with Client(cluster):
30+
ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4)
31+
32+
ddf_result = ddf_points.spatial_shuffle(
33+
by="hilbert", calculate_partitions=False
34+
)
35+
result = ddf_result.compute()
36+
37+
expected = df_points.sort_values("geometry").reset_index(drop=True)
38+
assert_geodataframe_equal(result.reset_index(drop=True), expected)

0 commit comments

Comments
 (0)