
Commit 0bf4f99

parquet working
1 parent: 6be29e8

File tree

1 file changed (+54, -16 lines)

dask_geopandas/io/parquet.py (54 additions, 16 deletions)
@@ -1,7 +1,6 @@
-from functools import partial
+from functools import cached_property, partial
 
 import dask.dataframe as dd
-
 import geopandas
 
 from .arrow import (
@@ -44,19 +43,34 @@ def _get_partition_bounds_parquet(part, fs):
     return _get_partition_bounds(pq_metadata.metadata)
 
 
+def _convert_to_list(column) -> list | None:
+    if column is None or isinstance(column, list):
+        pass
+    elif isinstance(column, tuple):
+        column = list(column)
+    elif hasattr(column, "dtype"):
+        column = column.tolist()
+    else:
+        column = [column]
+    return column
+
+
 class GeoArrowEngine(GeoDatasetEngine, DaskArrowDatasetEngine):
     """
     Engine for reading geospatial Parquet datasets. Subclasses dask's
     ArrowEngine for Parquet, but overriding some methods to ensure we
     correctly read/write GeoDataFrames.
 
     """
-
-    @classmethod
-    def read_metadata(cls, fs, paths, **kwargs):
-        meta, stats, parts, index = super().read_metadata(fs, paths, **kwargs)
-
-        gather_spatial_partitions = kwargs.pop("gather_spatial_partitions", True)
+    @cached_property
+    def _meta(self):
+        meta = super()._meta
+        gather_spatial_partitions = self._dataset_info.get(
+            "gather_spatial_partitions", True
+        )
+        fs = self._dataset_info["fs"]
+        parts = self._dataset_info["parts"]
+        breakpoint()
 
         if gather_spatial_partitions:
             regions = geopandas.GeoSeries(
@@ -67,7 +81,24 @@ def read_metadata(cls, fs, paths, **kwargs):
                 # a bit hacky, but this allows us to get this passed through
                 meta.attrs["spatial_partitions"] = regions
 
-        return (meta, stats, parts, index)
+        return meta
+
+    # @classmethod
+    # def read_metadata(cls, fs, paths, **kwargs):
+    #     meta, stats, parts, index = super().read_metadata(fs, paths, **kwargs)
+
+    #     gather_spatial_partitions = kwargs.pop("gather_spatial_partitions", True)
+
+    #     if gather_spatial_partitions:
+    #         regions = geopandas.GeoSeries(
+    #             [_get_partition_bounds_parquet(part, fs) for part in parts],
+    #             crs=meta.crs,
+    #         )
+    #         if regions.notna().all():
+    #             # a bit hacky, but this allows us to get this passed through
+    #             meta.attrs["spatial_partitions"] = regions
+
+    #     return (meta, stats, parts, index)
 
     @classmethod
     def _update_meta(cls, meta, schema):
@@ -77,13 +108,8 @@ def _update_meta(cls, meta, schema):
         return _update_meta_to_geodataframe(meta, schema.metadata)
 
     @classmethod
-    def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
-        """Overriding private method for dask >= 2021.10.0"""
-        if DASK_2022_12_0_PLUS and not DASK_2023_04_0:
-            meta = super()._create_dd_meta(dataset_info, use_nullable_dtypes)
-        else:
-            meta = super()._create_dd_meta(dataset_info)
-
+    def _create_dd_meta(cls, dataset_info):
+        meta = super()._create_dd_meta(dataset_info)
         schema = dataset_info["schema"]
         if not schema.names and not schema.metadata:
             if len(list(dataset_info["ds"].get_fragments())) == 0:
@@ -92,6 +118,18 @@ def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
                     "to read it as an empty DataFrame"
                 )
         meta = cls._update_meta(meta, schema)
+
+        if dataset_info["kwargs"].get("gather_spatial_partitions", True):
+            fs = dataset_info["fs"]
+            parts, _, _ = cls._construct_collection_plan(dataset_info)
+            regions = geopandas.GeoSeries(
+                [_get_partition_bounds_parquet(part, fs) for part in parts],
+                crs=meta.crs,
+            )
+            if regions.notna().all():
+                # a bit hacky, but this allows us to get this passed through
+                meta.attrs["spatial_partitions"] = regions
+
         return meta
 
 
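The new _convert_to_list helper normalizes a column selection into either None or a plain list. A minimal sketch of its behaviour, assuming the helper stays importable from dask_geopandas.io.parquet as in this commit (numpy appears here only to exercise the dtype branch):

    import numpy as np

    from dask_geopandas.io.parquet import _convert_to_list

    assert _convert_to_list(None) is None                        # None passes through
    assert _convert_to_list(["a", "b"]) == ["a", "b"]            # lists pass through
    assert _convert_to_list(("a", "b")) == ["a", "b"]            # tuples become lists
    assert _convert_to_list(np.array(["a", "b"])) == ["a", "b"]  # dtype branch: .tolist()
    assert _convert_to_list("a") == ["a"]                        # scalars are wrapped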
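The read_metadata override is dropped in favour of a _meta cached_property that pulls its inputs from self._dataset_info. A stripped-down sketch of that pattern with hypothetical names, showing why cached_property fits: the body runs once on first access and the result is cached on the instance.

    from functools import cached_property

    class Engine:
        def __init__(self, dataset_info):
            self._dataset_info = dataset_info

        @cached_property
        def _meta(self):
            # runs on first access only; the result is stored on the instance
            print("computing meta")
            return {"n_parts": len(self._dataset_info["parts"])}

    engine = Engine({"parts": [1, 2, 3]})
    engine._meta  # prints "computing meta"
    engine._meta  # served from the cache, no recomputation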
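End to end, the effect of the _create_dd_meta change is that per-file bounding boxes are gathered at read time unless the caller opts out via the gather_spatial_partitions keyword this diff reads from dataset_info["kwargs"]. A hypothetical usage sketch (the path is illustrative):

    import dask_geopandas

    # Default: partition bounds are read from the Parquet metadata and exposed
    # as a GeoSeries via the spatial_partitions attribute.
    ddf = dask_geopandas.read_parquet("tracts.parquet")
    print(ddf.spatial_partitions)

    # Opting out skips the extra metadata pass over the parts.
    ddf = dask_geopandas.read_parquet("tracts.parquet", gather_spatial_partitions=False)
    print(ddf.spatial_partitions)  # None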