Skip to content

Commit 8571ffe

Browse files
ENH: use total_bounds of spatial_partitions if available for Hilbert/Morton distance (#161)
1 parent 3057bfb commit 8571ffe

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

dask_geopandas/core.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,10 @@ def hilbert_distance(self, total_bounds=None, level=16):
348348
total_bounds : 4-element array, optional
349349
The spatial extent in which the curve is constructed (used to
350350
rescale the geometry midpoints). By default, the total bounds
351-
of the full dask GeoDataFrame will be computed. If known, you
352-
can pass the total bounds to avoid this extra computation.
351+
of the full dask GeoDataFrame will be computed (from the spatial
352+
partitions, if available, otherwise computed from the full
353+
dataframe). If known, you can pass the total bounds to avoid this
354+
extra computation.
353355
level : int (1 - 16), default 16
354356
Determines the precision of the curve (points on the curve will
355357
have coordinates in the range [0, 2^level - 1]).
@@ -362,7 +364,10 @@ def hilbert_distance(self, total_bounds=None, level=16):
362364
"""
363365
# Compute total bounds of all partitions rather than each partition
364366
if total_bounds is None:
365-
total_bounds = self.total_bounds
367+
if self.spatial_partitions is not None:
368+
total_bounds = self.spatial_partitions.total_bounds
369+
else:
370+
total_bounds = self.total_bounds
366371

367372
# Calculate hilbert distances for each partition
368373
distances = self.map_partitions(
@@ -396,20 +401,25 @@ def morton_distance(self, total_bounds=None, level=16):
396401
total_bounds : 4-element array, optional
397402
The spatial extent in which the curve is constructed (used to
398403
rescale the geometry midpoints). By default, the total bounds
399-
of the full dask GeoDataFrame will be computed. If known, you
400-
can pass the total bounds to avoid this extra computation.
404+
of the full dask GeoDataFrame will be computed (from the spatial
405+
partitions, if available, otherwise computed from the full
406+
dataframe). If known, you can pass the total bounds to avoid this
407+
extra computation.
401408
level : int (1 - 16), default 16
402409
Determines the precision of the Morton curve.
403410
404411
Returns
405412
-------
406413
dask.Series
407414
Series containing distances along the Morton curve
408-
"""
409415
416+
"""
410417
# Compute total bounds of all partitions rather than each partition
411418
if total_bounds is None:
412-
total_bounds = self.total_bounds
419+
if self.spatial_partitions is not None:
420+
total_bounds = self.spatial_partitions.total_bounds
421+
else:
422+
total_bounds = self.total_bounds
413423

414424
# Calculate Morton distances for each partition
415425
distances = self.map_partitions(

tests/test_hilbert_distance.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def test_specified_total_bounds(geoseries_polygons):
108108
assert_series_equal(result.compute(), expected.compute())
109109

110110

111+
def test_total_bounds_from_partitions(geoseries_polygons):
112+
ddf = from_geopandas(geoseries_polygons, npartitions=2)
113+
expected = ddf.hilbert_distance().compute()
114+
115+
ddf.calculate_spatial_partitions()
116+
result = ddf.hilbert_distance().compute()
117+
assert_series_equal(result, expected)
118+
119+
111120
def test_world():
112121
# world without Fiji
113122
hilbert_distance_dask(

tests/test_morton_distance.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def test_specified_total_bounds(geoseries_polygons):
7777
assert_series_equal(result.compute(), expected.compute())
7878

7979

80+
def test_total_bounds_from_partitions(geoseries_polygons):
81+
ddf = from_geopandas(geoseries_polygons, npartitions=2)
82+
expected = ddf.morton_distance().compute()
83+
84+
ddf.calculate_spatial_partitions()
85+
result = ddf.morton_distance().compute()
86+
assert_series_equal(result, expected)
87+
88+
8089
def test_world():
8190
# world without Fiji
8291
morton_distance_dask(

0 commit comments

Comments
 (0)