Skip to content

Commit d31d1fa

Browse files
authored
Merge pull request #118 from raybellwaves/transform
sel transform
2 parents dc07691 + 1e346e7 commit d31d1fa

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

intake_xarray/derived.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from intake import Schema
2+
from intake.source.derived import GenericTransform
3+
4+
5+
class XArrayTransform(GenericTransform):
    """Transform where the input and output are both xarray objects.

    You must supply ``transform`` and any ``transform_kwargs``.

    The transformed object is built on the first call to ``to_dask`` and
    cached on the instance in ``_ds``; later calls return the cached value.
    """

    input_container = "xarray"
    container = "xarray"
    optional_params = {}
    # Cache for the transformed object; ``None`` means "not computed yet".
    _ds = None

    def to_dask(self):
        """Return the (lazy) transformed object, computing it only once."""
        if self._ds is None:
            # NOTE(review): ``_pick`` is inherited; presumably it resolves
            # ``self._source`` and ``self._transform`` -- confirm in the base.
            self._pick()
            self._ds = self._transform(
                self._source.to_dask(), **self._params["transform_kwargs"]
            )
        return self._ds

    def _get_schema(self):
        """Load metadata only if needed and return a ``Schema`` for it."""
        self.to_dask()
        return Schema(
            datashape=None,
            dtype=None,
            shape=None,
            npartitions=None,
            # The transformed object is not guaranteed to carry an
            # ``extra_metadata`` attribute (plain xarray objects do not), so
            # fall back to an empty mapping instead of raising AttributeError.
            extra_metadata=getattr(self._ds, "extra_metadata", {}),
        )

    def read(self):
        """Return the transformed object after calling ``compute()`` on it."""
        return self.to_dask().compute()
36+
37+
38+
class Sel(XArrayTransform):
    """Simple array transform to subsample an xarray object using
    the sel method.

    Note that you could use XArrayTransform directly, by writing a
    function to choose the subsample instead of a method as here.
    """

    input_container = "xarray"
    container = "xarray"
    required_params = ["indexers"]

    def __init__(self, indexers, **kwargs):
        """
        indexers: dict, or a dict stored as a str (e.g. in a YAML catalog),
            which is passed to xarray.Dataset.sel
        """
        # this class wants required "indexers", but XArrayTransform
        # uses "transform_kwargs", which we don't need since we use a method
        # for the transform
        kwargs.update(
            transform=self.sel,
            indexers=indexers,
            transform_kwargs={},
        )
        super().__init__(**kwargs)

    def sel(self, ds):
        """Apply ``ds.sel`` with the configured ``indexers`` and return the result."""
        indexers = self._params["indexers"]
        if not isinstance(indexers, dict):
            # NOTE(security): ``eval`` executes arbitrary code found in the
            # catalog entry. Only load catalogs from trusted sources. It is
            # kept because existing catalogs store expressions such as
            # "dict([('lat', 20)])", which a literal-only parser would reject.
            indexers = eval(indexers)
        return ds.sel(indexers)

intake_xarray/tests/data/catalog.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,12 @@ sources:
114114
chunks: {}
115115
auth: null
116116
engine: netcdf4
117+
xarray_source_sel:
118+
description: select subsample of xarray_source entry
119+
driver: intake_xarray.derived.XArrayTransform
120+
args:
121+
targets:
122+
- xarray_source
123+
transform: "intake_xarray.tests.test_derived._sel"
124+
transform_kwargs:
125+
indexers: "dict([('lat', 20)])"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
import pytest
3+
4+
from intake import open_catalog
5+
from xarray.tests import assert_allclose
6+
7+
8+
# Function used in xarray_source_sel entry in catalog.yaml
9+
def _sel(ds, indexers: str):
10+
"""indexers: dict (stored as str) which is passed to xarray.Dataset.sel"""
11+
return ds.sel(eval(indexers))
12+
13+
14+
@pytest.fixture
def catalog():
    """Open ``data/catalog.yaml`` located next to this test module."""
    here = os.path.dirname(__file__)
    catalog_file = os.path.join(here, "data", "catalog.yaml")
    return open_catalog(catalog_file)
18+
19+
20+
def test_catalog(catalog):
    """The derived catalog entry must equal selecting ``lat=20`` by hand."""
    source_data = catalog["xarray_source"].read()
    expected = source_data.sel(lat=20)
    actual = catalog["xarray_source_sel"].read()
    assert_allclose(actual, expected)

0 commit comments

Comments
 (0)