Skip to content

Commit eb1d713

Browse files
authored
fix for to_polygons when using processes instead of threads in dask (scverse#756)
vectorize fix
1 parent d3cdf69 commit eb1d713

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/spatialdata/_core/operations/vectorize.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ def to_polygons(data: SpatialElement, buffer_resolution: int | None = None) -> G
160160
"""
161161
Convert a set of geometries (2D labels, 2D shapes) to approximated 2D polygons/multypolygons.
162162
163+
For optimal performance when converting rasters (:class:`xarray.DataArray` or :class:`datatree.DataTree`)
164+
to polygons, it is recommended to configure `Dask` to use 'processes' rather than 'threads'.
165+
For example, you can set this configuration with:
166+
167+
>>> import dask
168+
>>> dask.config.set(scheduler='processes')
169+
163170
Parameters
164171
----------
165172
data
@@ -194,23 +201,22 @@ def _(
194201
else:
195202
element_single_scale = element
196203

197-
gdf_chunks = []
198204
chunk_sizes = element_single_scale.data.chunks
199205

200-
def _vectorize_chunk(chunk: np.ndarray, yoff: int, xoff: int) -> None: # type: ignore[type-arg]
206+
def _vectorize_chunk(chunk: np.ndarray, yoff: int, xoff: int) -> GeoDataFrame: # type: ignore[type-arg]
201207
gdf = _vectorize_mask(chunk)
202208
gdf["chunk-location"] = f"({yoff}, {xoff})"
203209
gdf.geometry = gdf.translate(xoff, yoff)
204-
gdf_chunks.append(gdf)
210+
return gdf
205211

206212
tasks = [
207213
dask.delayed(_vectorize_chunk)(chunk, sum(chunk_sizes[0][:iy]), sum(chunk_sizes[1][:ix]))
208214
for iy, row in enumerate(element_single_scale.data.to_delayed())
209215
for ix, chunk in enumerate(row)
210216
]
211-
dask.compute(tasks)
212217

213-
gdf = pd.concat(gdf_chunks)
218+
results = dask.compute(*tasks)
219+
gdf = pd.concat(results)
214220
gdf = GeoDataFrame([_dissolve_on_overlaps(*item) for item in gdf.groupby("label")], columns=["label", "geometry"])
215221
gdf.index = gdf["label"]
216222

0 commit comments

Comments
 (0)