
Commit 5de0975

Merge pull request #90 from csiro-coasts/point-extraction-errors
Add `missing_points` argument to `extract_points`/`extract_dataframe`
2 parents 1bdfd77 + fd1eb83 commit 5de0975

5 files changed: +188 −45 lines changed


docs/releases/development.rst

Lines changed: 7 additions & 1 deletion
@@ -2,4 +2,10 @@
 Next release (in development)
 =============================
 
-* ...
+* Add `missing_points` parameter
+  to :func:`emsarray.operations.point_extraction.extract_points`
+  and :func:`emsarray.operations.point_extraction.extract_dataframe`.
+  Callers can now choose whether missing points raise an exception,
+  are dropped from the returned dataset,
+  or filled with a sensible fill value
+  (:pr:`90`).
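
As a quick orientation for readers of this change, here is a minimal usage sketch of the new parameter; the input file name is an illustrative assumption, not part of this commit, and the coordinates are taken from the docstring example below.

import pandas as pd
import xarray as xr

from emsarray.operations import point_extraction

# Hypothetical inputs: any emsarray-supported dataset plus a table of
# longitude/latitude pairs.
dataset = xr.open_dataset("model_output.nc")
points_df = pd.DataFrame({
    "lon": [152.8, 152.7, 153.5],
    "lat": [-24.96, -24.59, -25.49],
})

# missing_points="drop" discards points that fall outside the model
# geometry instead of raising NonIntersectingPoints; the returned
# "point" coordinate records which dataframe rows survived.
point_dataset = point_extraction.extract_dataframe(
    dataset, points_df, ("lon", "lat"), missing_points="drop")
print(point_dataset["point"].values)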

src/emsarray/cli/commands/extract_points.py

Lines changed: 12 additions & 2 deletions
@@ -49,7 +49,16 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
             metavar=('DIM'),
             default="point",
             help=(
-                "Name of the new dimension to index the point data"
+                "Name of the new dimension to index the point data."
+            ))
+
+        parser.add_argument(
+            "--missing-points",
+            choices=['error', 'drop', 'fill'],
+            default='error',
+            help=(
+                "What to do when a point does not intersect the dataset geometry. "
+                "Defaults to 'error'."
             ))
 
     def handle(self, options: argparse.Namespace) -> None:
@@ -60,7 +69,8 @@ def handle(self, options: argparse.Namespace) -> None:
         try:
             point_data = point_extraction.extract_dataframe(
                 dataset, dataframe, options.coordinate_columns,
-                point_dimension=options.point_dimension)
+                point_dimension=options.point_dimension,
+                missing_points=options.missing_points)
         except point_extraction.NonIntersectingPoints as err:
             rows = dataframe.iloc[err.indices]
             raise CommandException(
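
The error handling this command performs is also available to library callers; a sketch under the assumption of placeholder file paths and ('lon', 'lat') coordinate columns.

import pandas as pd
import xarray as xr

from emsarray.operations import point_extraction

dataset = xr.open_dataset("model_output.nc")   # placeholder path
dataframe = pd.read_csv("points.csv")          # placeholder path

try:
    point_data = point_extraction.extract_dataframe(
        dataset, dataframe, ("lon", "lat"), missing_points="error")
except point_extraction.NonIntersectingPoints as err:
    # err.indices holds the positional indices of the offending rows,
    # so the rows themselves can be reported back to the user.
    print("Points that do not intersect the dataset geometry:")
    print(dataframe.iloc[err.indices])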

src/emsarray/operations/point_extraction.py

Lines changed: 62 additions & 37 deletions
@@ -11,19 +11,16 @@
 and returns a new dataset without any associated geometry.
 This is useful if you want to add your own metadata to the subset dataset.
 
-If any of the supplied points does not intersect the dataset geometry,
-a :exc:`.NonIntersectingPoints` exception is raised.
-This will include the indices of the points that do not intersect.
-
 :ref:`emsarray extract-points` is a command line interface to :func:`.extract_dataframe`.
 """
 import dataclasses
-from typing import Hashable, List, Tuple
+from typing import Any, Hashable, List, Literal, Tuple
 
 import numpy as np
 import pandas as pd
+import shapely
 import xarray as xr
-from shapely.geometry import Point
+import xarray.core.dtypes as xrdtypes
 
 from emsarray.conventions import Convention
 
@@ -38,7 +35,7 @@ class NonIntersectingPoints(ValueError):
     indices: np.ndarray
 
     #: The non-intersecting points
-    points: List[Point]
+    points: List[shapely.Point]
 
     def __post_init__(self) -> None:
         super().__init__(f"{self.points[0].wkt} does not intersect the dataset geometry")
@@ -49,24 +46,19 @@ def _dataframe_to_dataset(
     *,
     dimension_name: Hashable,
 ) -> xr.Dataset:
-    """
-    Convert a pandas DataFrame to an xarray Dataset.
-    pandas adds an 'index' coordinate that numbers the 'index' dimension.
-    We don't need the coordinate, and the dimension needs to be renamed.
-    """
-    index_name = dataframe.index.name or 'index'
+    """Convert a pandas DataFrame to an xarray Dataset."""
+    dataframe = dataframe.copy()
+    dataframe.index.name = dimension_name
     dataset = dataframe.to_xarray()
-    dataset = dataset.drop_vars(index_name)
-    if dimension_name != index_name:
-        dataset = dataset.rename_dims({index_name: dimension_name})
     return dataset
 
 
 def extract_points(
     dataset: xr.Dataset,
-    points: List[Point],
+    points: List[shapely.Point],
     *,
     point_dimension: Hashable = 'point',
+    missing_points: Literal['error', 'drop'] = 'error',
 ) -> xr.Dataset:
     """
     Drop all data except for cells that intersect the given points.
@@ -80,19 +72,27 @@
     ----------
     dataset : xarray.Dataset
         The dataset to extract point data from.
-    points : list of :class:`Point`
+    points : list of :class:`shapely.Point`
         The points to select.
     point_dimension : Hashable, optional
         The name of the new dimension to index points along.
         Defaults to ``"point"``.
+    missing_points : {'error', 'drop'}, default 'error'
+        How to handle points which do not intersect the dataset geometry.
+
+        - If 'error', a :exc:`NonIntersectingPoints` exception is raised.
+        - If 'drop', the points are dropped from the returned dataset.
 
     Returns
     -------
     xarray.Dataset
         A subset of the input dataset that only contains data at the given points.
-        The dataset will only contain the values, without any coordinate information.
+        The dataset will only contain the values, without any geometry coordinates.
+        The `point_dimension` dimension will have a coordinate with the same name
+        whose values match the indices of the `points` array.
+        This is useful when `missing_points` is 'drop' to find out which points were dropped.
 
-    See also
+    See Also
     --------
     :func:`extract_dataframe`
     """
@@ -101,23 +101,28 @@
     # Find the indexer for each given point
     indexes = np.array([convention.get_index_for_point(point) for point in points])
 
-    # TODO It would be nicer if out-of-bounds points were represented in the
-    # output by masked values, rather than raising an error.
-    out_of_bounds = np.flatnonzero(np.equal(indexes, None))  # type: ignore
-    if len(out_of_bounds):
-        raise NonIntersectingPoints(
-            indices=out_of_bounds,
-            points=[points[i] for i in out_of_bounds])
+    if missing_points == 'error':
+        out_of_bounds = np.flatnonzero(np.equal(indexes, None))  # type: ignore
+        if len(out_of_bounds):
+            raise NonIntersectingPoints(
+                indices=out_of_bounds,
+                points=[points[i] for i in out_of_bounds])
 
     # Make a DataFrame out of all point indexers
     selector_df = pd.DataFrame([
         convention.selector_for_index(index.index)
         for index in indexes
         if index is not None])
+    point_indexes = [i for i, index in enumerate(indexes) if index is not None]
 
     # Subset the dataset to the points
+    point_ds = convention.drop_geometry()
     selector_ds = _dataframe_to_dataset(selector_df, dimension_name=point_dimension)
-    return convention.drop_geometry().isel(selector_ds)
+    point_ds = point_ds.isel(selector_ds)
+    point_ds = point_ds.assign_coords({
+        point_dimension: ([point_dimension], point_indexes),
+    })
+    return point_ds
 
 
 def extract_dataframe(
@@ -126,6 +131,8 @@
     coordinate_columns: Tuple[str, str],
     *,
     point_dimension: Hashable = 'point',
+    missing_points: Literal['error', 'drop', 'fill'] = 'error',
+    fill_value: Any = xrdtypes.NA,
 ) -> xr.Dataset:
     """
     Extract the points listed in a pandas :class:`~pandas.DataFrame`,
@@ -143,12 +150,27 @@
     point_dimension : Hashable, optional
         The name of the new dimension to create in the dataset.
         Optional, defaults to "point".
+    missing_points : {'error', 'drop', 'fill'}, default 'error'
+        How to handle points that do not intersect the dataset geometry:
+
+        - 'error' will raise a :class:`.NonIntersectingPoints` exception.
+        - 'drop' will drop those points from the dataset.
+        - 'fill' will include those points, but their data variables
+          will be filled with an appropriate fill value
+          such as :data:`numpy.nan` for float values.
+    fill_value
+        Passed to :meth:`xarray.Dataset.merge` when `missing_points` is 'fill'.
+        See the documentation for that method for all options.
+        Defaults to a sensible fill value for each variable's dtype.
 
     Returns
     -------
     xarray.Dataset
         A new dataset that only contains data at the given points,
         plus any new columns present in the dataframe.
+        The `point_dimension` dimension will have a coordinate with the same name
+        whose values match the row numbers of the dataframe.
+        This is useful when `missing_points` is "drop" to find out which points were dropped.
 
     Example
     -------
@@ -176,9 +198,10 @@
         Coordinates:
             zc       (k) float32 ...
           * time     (time) datetime64[ns] 2022-05-11T14:00:00
+          * point    (point) int64 0 1 2
             lon      (point) float64 152.8 152.7 153.5
            lat      (point) float64 -24.96 -24.59 -25.49
-        Dimensions without coordinates: k, point
+        Dimensions without coordinates: k
         Data variables:
             botz     (point) float32 ...
             eta      (time, point) float32 ...
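
As a companion to the doctest output above, a brief sketch of the 'fill' mode with an explicit fill value; the file path is a placeholder and the second point is assumed to fall outside the model geometry.

import numpy as np
import pandas as pd
import xarray as xr

from emsarray.operations import point_extraction

dataset = xr.open_dataset("model_output.nc")   # placeholder path
points_df = pd.DataFrame({
    "lon": [152.8, 170.0],   # second point assumed to miss the model geometry
    "lat": [-24.96, -60.0],
})

# Every dataframe row is kept; data variables at non-intersecting points
# are filled via xarray.Dataset.merge, here explicitly with NaN.
point_dataset = point_extraction.extract_dataframe(
    dataset, points_df, ("lon", "lat"),
    missing_points="fill", fill_value=np.nan)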
@@ -191,24 +214,26 @@
     lon_coord, lat_coord = coordinate_columns
 
     # Extract the points from the dataset
-    points = [
-        Point(row[lon_coord], row[lat_coord])
-        for i, row in dataframe.iterrows()]
-    point_dataset = extract_points(dataset, points, point_dimension=point_dimension)
+    points = shapely.points(np.c_[dataframe[lon_coord], dataframe[lat_coord]])
+
+    point_dataset = extract_points(
+        dataset, points, point_dimension=point_dimension,
+        missing_points='error' if missing_points == 'error' else 'drop')
+    coord_dataset = _dataframe_to_dataset(dataframe, dimension_name=point_dimension)
 
     # Merge in the dataframe
-    point_dataset = point_dataset.merge(_dataframe_to_dataset(
-        dataframe, dimension_name=point_dimension))
+    join: Literal['outer', 'inner'] = 'outer' if missing_points == 'fill' else 'inner'
+    point_dataset = point_dataset.merge(coord_dataset, join=join, fill_value=fill_value)
     point_dataset = point_dataset.set_coords(coordinate_columns)
 
     # Add CF attributes to the new coordinate variables
     point_dataset[lon_coord].attrs.update({
-        "long_name": "longitude",
+        "long_name": "Longitude",
         "units": "degrees_east",
         "standard_name": "longitude",
     })
     point_dataset[lat_coord].attrs.update({
-        "long_name": "latitude",
+        "long_name": "Latitude",
         "units": "degrees_north",
         "standard_name": "latitude",
     })
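
The 'drop' and 'fill' behaviours fall out of the join mode passed to xarray.Dataset.merge: the dataframe-derived dataset covers every requested point, while the extracted dataset only covers intersecting points, so an inner join drops the difference and an outer join pads it with the fill value. A toy illustration with plain xarray (invented data, not emsarray variables):

import numpy as np
import xarray as xr

# Extracted data exists only for points 0 and 2 ...
extracted = xr.Dataset(
    {"values": ("point", [1.5, 2.5])},
    coords={"point": [0, 2]},
)
# ... while the dataframe-derived dataset covers every requested point.
coords = xr.Dataset(
    {"name": ("point", ["a", "b", "c"])},
    coords={"point": [0, 1, 2]},
)

dropped = extracted.merge(coords, join="inner")                    # like missing_points='drop'
filled = extracted.merge(coords, join="outer", fill_value=np.nan)  # like missing_points='fill'
print(filled["values"].values)   # [1.5  nan  2.5]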

tests/operations/point_extraction/test_extract_dataframe.py

Lines changed: 79 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import numpy as np
 import pandas as pd
+import pytest
 import xarray as xr
 from numpy.testing import assert_equal
 from shapely.geometry import Point
@@ -36,12 +37,12 @@ def test_extract_dataframe(
 
     # The new point coordinate variables should have the relevant CF attributes
     assert point_dataset['lon'].attrs == {
-        "long_name": "longitude",
+        "long_name": "Longitude",
         "units": "degrees_east",
         "standard_name": "longitude",
     }
     assert point_dataset['lat'].attrs == {
-        "long_name": "latitude",
+        "long_name": "Latitude",
         "units": "degrees_north",
         "standard_name": "latitude",
     }
@@ -55,3 +56,79 @@
         in_dataset.ems.select_point(Point(row['lon'], row['lat']))['values'].values
         for i, row in points_df.iterrows()
     ])
+
+
+def test_extract_dataframe_point_dimension(
+    datasets: pathlib.Path,
+) -> None:
+    points_df = pd.DataFrame({
+        'name': ['a', 'b', 'c', 'd'],
+        'lon': [0, 1, 2, 3],
+        'lat': [0, 0, 0, 0],
+    })
+    in_dataset = xr.open_dataset(datasets / 'ugrid_mesh2d.nc')
+    point_dataset = point_extraction.extract_dataframe(
+        in_dataset, points_df, ('lon', 'lat'), point_dimension='foo')
+    assert point_dataset.dims['foo'] == 4
+    assert point_dataset['foo'].dims == ('foo',)
+    assert_equal(point_dataset['foo'].values, [0, 1, 2, 3])
+    assert point_dataset['values'].dims == ('foo',)
+
+
+def test_extract_points_missing_point_error(
+    datasets: pathlib.Path,
+) -> None:
+    points_df = pd.DataFrame({
+        'name': ['a', 'b', 'c', 'd'],
+        'lon': [0, 10, 1, 20],
+        'lat': [0, 0, 0, 0],
+    })
+    in_dataset = xr.open_dataset(datasets / 'ugrid_mesh2d.nc')
+    with pytest.raises(point_extraction.NonIntersectingPoints) as exc_info:
+        point_extraction.extract_dataframe(in_dataset, points_df, ('lon', 'lat'))
+    exc: point_extraction.NonIntersectingPoints = exc_info.value
+    assert_equal(exc.indices, [1, 3])
+
+
+def test_extract_points_missing_point_drop(
+    datasets: pathlib.Path,
+) -> None:
+    points_df = pd.DataFrame({
+        'name': ['a', 'b', 'c', 'd'],
+        'lon': [0, 10, 1, 20],
+        'lat': [0, 0, 0, 0],
+    })
+    in_dataset = xr.open_dataset(datasets / 'ugrid_mesh2d.nc')
+    point_dataset = point_extraction.extract_dataframe(
+        in_dataset, points_df, ('lon', 'lat'), missing_points='drop')
+    assert point_dataset.dims['point'] == 2
+    assert 'values' in point_dataset.data_vars
+    assert_equal(point_dataset['values'].values, [
+        in_dataset.ems.select_point(Point(0, 0))['values'].values,
+        in_dataset.ems.select_point(Point(1, 0))['values'].values,
+    ])
+    assert_equal(point_dataset['point'].values, [0, 2])
+    assert_equal(point_dataset['name'].values, ['a', 'c'])
+
+
+def test_extract_points_missing_point_fill(
+    datasets: pathlib.Path,
+) -> None:
+    points_df = pd.DataFrame({
+        'name': ['a', 'b', 'c', 'd'],
+        'lon': [0, 10, 1, 20],
+        'lat': [0, 0, 0, 0],
+    })
+    in_dataset = xr.open_dataset(datasets / 'ugrid_mesh2d.nc')
+    point_dataset = point_extraction.extract_dataframe(
+        in_dataset, points_df, ('lon', 'lat'), missing_points='fill')
+    assert point_dataset.dims['point'] == 4
+    assert 'values' in point_dataset.data_vars
+    assert_equal(point_dataset['values'].values, [
+        in_dataset.ems.select_point(Point(0, 0))['values'].values,
+        np.nan,
+        in_dataset.ems.select_point(Point(1, 0))['values'].values,
+        np.nan,
+    ])
+    assert_equal(point_dataset['point'].values, [0, 1, 2, 3])
+    assert_equal(point_dataset['name'].values, ['a', 'b', 'c', 'd'])
