Skip to content

Commit f6390c6

Browse files
authored
Merge pull request #93 from csiro-coasts/get_all_geometry_names
Add `Convention.select_variables()`
2 parents 5c5f161 + 5f3b371 commit f6390c6

File tree

6 files changed

+96
-27
lines changed

6 files changed

+96
-27
lines changed

docs/releases/development.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ Next release (in development)
1313
Add :attr:`.Convention.time_coordinate` and :attr:`.Convention.depth_coordinate`,
1414
deprecate :meth:`.Convention.get_times()` and :meth:`.Convention.get_depths()`
1515
(:pr:`92`).
16+
* Add :meth:`.Convention.select_variables` (:pr:`93`).

src/emsarray/conventions/_base.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import warnings
88
from functools import cached_property
99
from typing import (
10-
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, List,
11-
Optional, Tuple, TypeVar, Union, cast
10+
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, Iterable,
11+
List, Optional, Tuple, TypeVar, Union, cast
1212
)
1313

1414
import numpy as np
@@ -1257,14 +1257,63 @@ def select_point(self, point: Point) -> xr.Dataset:
12571257
return self.select_index(index.index)
12581258

12591259
@abc.abstractmethod
1260+
def get_all_geometry_names(self) -> List[Hashable]:
1261+
"""
1262+
Return a list of the names of all geometry variables used by this convention.
1263+
1264+
See Also
1265+
--------
1266+
drop_geometry
1267+
select_variables
1268+
"""
1269+
pass
1270+
12601271
def drop_geometry(self) -> xr.Dataset:
12611272
"""
12621273
Return a new :class:`xarray.Dataset`
12631274
with all geometry variables dropped.
12641275
Useful when significantly transforming the dataset,
12651276
such as :mod:`extracting point data <emsarray.operations.point_extraction>`.
1277+
1278+
See Also
1279+
--------
1280+
get_all_geometry_names
1281+
select_variables
12661282
"""
1267-
pass
1283+
return self.dataset.drop_vars(self.get_all_geometry_names())
1284+
1285+
def select_variables(self, variables: Iterable[Hashable]) -> xr.Dataset:
1286+
"""Select only a subset of the variables in this dataset, dropping all others.
1287+
1288+
This will keep all coordinate variables and all geometry variables.
1289+
1290+
Parameters
1291+
----------
1292+
variables : iterable of Hashable
1293+
The names of all data variables to select.
1294+
1295+
Returns
1296+
-------
1297+
xarray.DataArray
1298+
A new dataset with the same geometry and coordinates,
1299+
but only the selected data variables.
1300+
1301+
See also
1302+
--------
1303+
get_all_geometry_names
1304+
drop_geometry
1305+
"""
1306+
all_vars = set(self.dataset.variables.keys())
1307+
keep_vars = {
1308+
*variables,
1309+
*self.get_all_geometry_names(),
1310+
*self.get_all_depth_names(),
1311+
}
1312+
try:
1313+
keep_vars.add(self.get_time_name())
1314+
except NoSuchCoordinateError:
1315+
pass
1316+
return self.dataset.drop_vars(all_vars - keep_vars)
12681317

12691318
@abc.abstractmethod
12701319
def make_clip_mask(

src/emsarray/conventions/arakawa_c.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import enum
1212
import logging
1313
from functools import cached_property
14-
from typing import Dict, Hashable, Optional, Tuple, cast
14+
from typing import Dict, Hashable, List, Optional, Tuple, cast
1515

1616
import numpy as np
1717
import xarray as xr
@@ -307,8 +307,8 @@ def selector_for_index(self, index: ArakawaCIndex) -> Dict[Hashable, int]:
307307
topology = self._topology_for_grid_kind[kind]
308308
return {topology.j_dimension: j, topology.i_dimension: i}
309309

310-
def drop_geometry(self) -> xr.Dataset:
311-
variables = [
310+
def get_all_geometry_names(self) -> List[Hashable]:
311+
return [
312312
self.face.longitude.name,
313313
self.face.latitude.name,
314314
self.node.longitude.name,
@@ -318,7 +318,6 @@ def drop_geometry(self) -> xr.Dataset:
318318
self.back.longitude.name,
319319
self.back.latitude.name,
320320
]
321-
return self.dataset.drop_vars(variables)
322321

323322
def make_linear(self, data_array: xr.DataArray) -> xr.DataArray:
324323
kind, size = self.get_grid_kind_and_size(data_array)

src/emsarray/conventions/grid.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from contextlib import suppress
1212
from functools import cached_property
1313
from typing import (
14-
Dict, Generic, Hashable, Optional, Tuple, Type, TypeVar, cast
14+
Dict, Generic, Hashable, List, Optional, Tuple, Type, TypeVar, cast
1515
)
1616

1717
import numpy as np
@@ -264,7 +264,10 @@ def unravel_index(
264264
def ravel_index(self, indices: CFGridIndex) -> int:
265265
return int(np.ravel_multi_index(indices, self.topology.shape))
266266

267-
def get_grid_kind_and_size(self, data_array: xr.DataArray) -> Tuple[CFGridKind, int]:
267+
def get_grid_kind_and_size(
268+
self,
269+
data_array: xr.DataArray,
270+
) -> Tuple[CFGridKind, int]:
268271
expected = {self.topology.y_dimension, self.topology.x_dimension}
269272
dims = set(data_array.dims)
270273
if dims.issuperset(expected):
@@ -277,13 +280,27 @@ def selector_for_index(self, index: CFGridIndex) -> Dict[Hashable, int]:
277280
y, x = index
278281
return {self.topology.y_dimension: y, self.topology.x_dimension: x}
279282

280-
def drop_geometry(self) -> xr.Dataset:
281-
dataset = self.dataset.drop_vars([
283+
def get_all_geometry_names(self) -> List[Hashable]:
284+
# Grid datasets contain latitude and longitude variables
285+
# plus optional bounds variables.
286+
names = [
282287
self.topology.longitude_name,
283288
self.topology.latitude_name,
284-
])
285-
dataset.attrs.pop('Conventions', None)
289+
]
286290

291+
bounds_names: List[Optional[Hashable]] = [
292+
self.topology.longitude.attrs.get('bounds', None),
293+
self.topology.latitude.attrs.get('bounds', None),
294+
]
295+
for bounds_name in bounds_names:
296+
if bounds_name is not None and bounds_name in self.dataset.variables:
297+
names.append(bounds_name)
298+
299+
return names
300+
301+
def drop_geometry(self) -> xr.Dataset:
302+
dataset = super().drop_geometry()
303+
dataset.attrs.pop('Conventions', None)
287304
return dataset
288305

289306
def make_linear(self, data_array: xr.DataArray) -> xr.DataArray:

src/emsarray/conventions/ugrid.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,34 +1315,34 @@ def integer_indices(data_array: xr.DataArray) -> np.ndarray:
13151315
new_dataset = xr.open_mfdataset(mfdataset_paths, lock=False)
13161316
return utils.dataset_like(dataset, new_dataset)
13171317

1318-
def drop_geometry(self) -> xr.Dataset:
1319-
dataset = self.dataset
1318+
def get_all_geometry_names(self) -> List[Hashable]:
13201319
topology = self.topology
13211320

1322-
geometry_variables = [
1321+
names = [
13231322
topology.mesh_variable.name,
13241323
topology.face_node_connectivity.name,
13251324
topology.node_x.name,
13261325
topology.node_y.name,
13271326
]
13281327
if topology.has_valid_face_edge_connectivity:
1329-
geometry_variables.append(topology.face_edge_connectivity.name)
1328+
names.append(topology.face_edge_connectivity.name)
13301329
if topology.has_valid_face_face_connectivity:
1331-
geometry_variables.append(topology.face_face_connectivity.name)
1330+
names.append(topology.face_face_connectivity.name)
13321331
if topology.has_valid_edge_node_connectivity:
1333-
geometry_variables.append(topology.edge_node_connectivity.name)
1332+
names.append(topology.edge_node_connectivity.name)
13341333
if topology.has_valid_edge_face_connectivity:
1335-
geometry_variables.append(topology.edge_face_connectivity.name)
1334+
names.append(topology.edge_face_connectivity.name)
13361335
if topology.edge_x is not None:
1337-
geometry_variables.append(topology.edge_x.name)
1336+
names.append(topology.edge_x.name)
13381337
if topology.edge_y is not None:
1339-
geometry_variables.append(topology.edge_y.name)
1338+
names.append(topology.edge_y.name)
13401339
if topology.face_x is not None:
1341-
geometry_variables.append(topology.face_x.name)
1340+
names.append(topology.face_x.name)
13421341
if topology.face_y is not None:
1343-
geometry_variables.append(topology.face_y.name)
1342+
names.append(topology.face_y.name)
1343+
return names
13441344

1345-
dataset = dataset.drop_vars(geometry_variables)
1345+
def drop_geometry(self) -> xr.Dataset:
1346+
dataset = super().drop_geometry()
13461347
dataset.attrs.pop('Conventions', None)
1347-
13481348
return dataset

tests/conventions/test_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import enum
55
import pathlib
66
from functools import cached_property
7-
from typing import Dict, Hashable, Optional, Tuple
7+
from typing import Dict, Hashable, List, Optional, Tuple
88

99
import matplotlib.pyplot as plt
1010
import numpy as np
@@ -54,6 +54,9 @@ def get_grid_kind_and_size(self, data_array: xr.DataArray) -> Tuple[SimpleGridKi
5454
return (SimpleGridKind.face, int(np.prod(self.shape)))
5555
raise ValueError("Invalid dimensions")
5656

57+
def get_all_geometry_names(self) -> List[Hashable]:
58+
return ['x', 'y']
59+
5760
def unravel_index(
5861
self,
5962
index: int,

0 commit comments

Comments
 (0)