@@ -1,7 +1,6 @@
-from functools import partial
+from functools import cached_property, partial

 import dask.dataframe as dd
-
 import geopandas

 from .arrow import (
@@ -44,19 +43,34 @@ def _get_partition_bounds_parquet(part, fs):
     return _get_partition_bounds(pq_metadata.metadata)


+def _convert_to_list(column) -> list | None:
+    if column is None or isinstance(column, list):
+        pass
+    elif isinstance(column, tuple):
+        column = list(column)
+    elif hasattr(column, "dtype"):
+        column = column.tolist()
+    else:
+        column = [column]
+    return column
+
+
 class GeoArrowEngine(GeoDatasetEngine, DaskArrowDatasetEngine):
     """
     Engine for reading geospatial Parquet datasets. Subclasses dask's
     ArrowEngine for Parquet, but overriding some methods to ensure we
     correctly read/write GeoDataFrames.

     """
-
-    @classmethod
-    def read_metadata(cls, fs, paths, **kwargs):
-        meta, stats, parts, index = super().read_metadata(fs, paths, **kwargs)
-
-        gather_spatial_partitions = kwargs.pop("gather_spatial_partitions", True)
+    @cached_property
+    def _meta(self):
+        meta = super()._meta
+        gather_spatial_partitions = self._dataset_info.get(
+            "gather_spatial_partitions", True
+        )
+        fs = self._dataset_info["fs"]
+        parts = self._dataset_info["parts"]
+        breakpoint()

         if gather_spatial_partitions:
             regions = geopandas.GeoSeries(
@@ -67,7 +81,24 @@ def read_metadata(cls, fs, paths, **kwargs):
                 # a bit hacky, but this allows us to get this passed through
                 meta.attrs["spatial_partitions"] = regions

-        return (meta, stats, parts, index)
+        return meta
+
+    # @classmethod
+    # def read_metadata(cls, fs, paths, **kwargs):
+    #     meta, stats, parts, index = super().read_metadata(fs, paths, **kwargs)
+
+    #     gather_spatial_partitions = kwargs.pop("gather_spatial_partitions", True)
+
+    #     if gather_spatial_partitions:
+    #         regions = geopandas.GeoSeries(
+    #             [_get_partition_bounds_parquet(part, fs) for part in parts],
+    #             crs=meta.crs,
+    #         )
+    #         if regions.notna().all():
+    #             # a bit hacky, but this allows us to get this passed through
+    #             meta.attrs["spatial_partitions"] = regions
+
+    #     return (meta, stats, parts, index)

     @classmethod
     def _update_meta(cls, meta, schema):
@@ -77,13 +108,8 @@ def _update_meta(cls, meta, schema):
         return _update_meta_to_geodataframe(meta, schema.metadata)

     @classmethod
-    def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
-        """Overriding private method for dask >= 2021.10.0"""
-        if DASK_2022_12_0_PLUS and not DASK_2023_04_0:
-            meta = super()._create_dd_meta(dataset_info, use_nullable_dtypes)
-        else:
-            meta = super()._create_dd_meta(dataset_info)
-
+    def _create_dd_meta(cls, dataset_info):
+        meta = super()._create_dd_meta(dataset_info)
         schema = dataset_info["schema"]
         if not schema.names and not schema.metadata:
             if len(list(dataset_info["ds"].get_fragments())) == 0:
@@ -92,6 +118,18 @@ def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
                     "to read it as an empty DataFrame"
                 )
         meta = cls._update_meta(meta, schema)
+
+        if dataset_info["kwargs"].get("gather_spatial_partitions", True):
+            fs = dataset_info["fs"]
+            parts, _, _ = cls._construct_collection_plan(dataset_info)
+            regions = geopandas.GeoSeries(
+                [_get_partition_bounds_parquet(part, fs) for part in parts],
+                crs=meta.crs,
+            )
+            if regions.notna().all():
+                # a bit hacky, but this allows us to get this passed through
+                meta.attrs["spatial_partitions"] = regions
+
         return meta

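Note: a minimal usage sketch of what this change targets, assuming the public dask_geopandas.read_parquet entry point, its gather_spatial_partitions keyword (the value that lands in dataset_info["kwargs"] above), and a placeholder dataset path:

import dask_geopandas

# With gather_spatial_partitions=True (assumed default), the engine reads per-file
# bounds from the Parquet metadata and, when every partition has bounds, attaches
# them as a GeoSeries of bounding polygons.
ddf = dask_geopandas.read_parquet(
    "path/to/dataset/",  # placeholder path
    gather_spatial_partitions=True,
)

# One bounding geometry per partition, or None when bounds could not be gathered.
print(ddf.spatial_partitions)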