Skip to content

Commit 1bcd5e3

Browse files
committed
Add Convention.select_variables()
A new dataset will be returned that contains only the specified data variables, plus all the geometry, depth, and time variables in the dataset. Useful for dropping all but a few data variables in a dataset without loosing all the geometry information.
1 parent 3b2e789 commit 1bcd5e3

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

src/emsarray/conventions/_base.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import warnings
88
from functools import cached_property
99
from typing import (
10-
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, List,
11-
Optional, Tuple, TypeVar, Union, cast
10+
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Generic, Hashable, Iterable,
11+
List, Optional, Tuple, TypeVar, Union, cast
1212
)
1313

1414
import numpy as np
@@ -1264,6 +1264,7 @@ def get_all_geometry_names(self) -> List[Hashable]:
12641264
See Also
12651265
--------
12661266
drop_geometry
1267+
select_variables
12671268
"""
12681269
pass
12691270

@@ -1277,9 +1278,43 @@ def drop_geometry(self) -> xr.Dataset:
12771278
See Also
12781279
--------
12791280
get_all_geometry_names
1281+
select_variables
12801282
"""
12811283
return self.dataset.drop_vars(self.get_all_geometry_names())
12821284

1285+
def select_variables(self, variables: Iterable[Hashable]) -> xr.Dataset:
1286+
"""Select only a subset of the variables in this dataset, dropping all others.
1287+
1288+
This will keep all coordinate variables and all geometry variables.
1289+
1290+
Parameters
1291+
----------
1292+
variables : iterable of Hashable
1293+
The names of all data variables to select.
1294+
1295+
Returns
1296+
-------
1297+
xarray.DataArray
1298+
A new dataset with the same geometry and coordinates,
1299+
but only the selected data variables.
1300+
1301+
See also
1302+
--------
1303+
get_all_geometry_names
1304+
drop_geometry
1305+
"""
1306+
all_vars = set(self.dataset.variables.keys())
1307+
keep_vars = {
1308+
*variables,
1309+
*self.get_all_geometry_names(),
1310+
*self.get_all_depth_names(),
1311+
}
1312+
try:
1313+
keep_vars.add(self.get_time_name())
1314+
except NoSuchCoordinateError:
1315+
pass
1316+
return self.dataset.drop_vars(all_vars - keep_vars)
1317+
12831318
@abc.abstractmethod
12841319
def make_clip_mask(
12851320
self,

0 commit comments

Comments
 (0)