|
| 1 | +from intake import Schema |
| 2 | +from intake.source.derived import GenericTransform |
| 3 | + |
| 4 | + |
class XArrayTransform(GenericTransform):
    """Transform where the input and output are both xarray objects.

    You must supply ``transform`` and any ``transform_kwargs``.
    The transformed object is built lazily on first use and cached on
    the instance in ``_ds``.
    """

    input_container = "xarray"
    container = "xarray"
    optional_params = {}
    # Cache for the transformed xarray object; None until first to_dask().
    _ds = None

    def to_dask(self):
        """Return the transformed xarray object, computing it on first call.

        The upstream source's ``to_dask()`` result is passed to the
        transform together with the ``transform_kwargs`` parameters.
        """
        if self._ds is None:
            # _pick() is inherited; presumably it resolves self._source and
            # self._transform -- defined in GenericTransform, not visible here.
            self._pick()
            self._ds = self._transform(
                self._source.to_dask(), **self._params["transform_kwargs"]
            )
        return self._ds

    def _get_schema(self):
        """Build the schema, loading the transformed object only if needed."""
        ds = self.to_dask()
        # A plain xarray object has no ``extra_metadata`` attribute unless that
        # name happens to be a key in its attrs, so unconditional access can
        # raise AttributeError.  Fall back to the object's attrs instead.
        extra_metadata = getattr(ds, "extra_metadata", None)
        if extra_metadata is None:
            extra_metadata = dict(getattr(ds, "attrs", None) or {})
        return Schema(
            datashape=None,
            dtype=None,
            shape=None,
            npartitions=None,
            extra_metadata=extra_metadata,
        )

    def read(self):
        """Return the transformed object with its data computed in memory."""
        return self.to_dask().compute()
| 36 | + |
| 37 | + |
class Sel(XArrayTransform):
    """Simple array transform to subsample an xarray object using
    the sel method.

    Note that you could use XArrayTransform directly, by writing a
    function to choose the subsample instead of a method as here.
    """

    input_container = "xarray"
    container = "xarray"
    required_params = ["indexers"]

    def __init__(self, indexers, **kwargs):
        """
        indexers: dict, or a str holding a Python expression that evaluates
            to a dict; the result is passed to the ``sel`` method of the
            input xarray object.
        """
        # This class wants the required "indexers" parameter, but
        # XArrayTransform uses "transform_kwargs", which we don't need since
        # we use a method for the transform.
        kwargs.update(
            transform=self.sel,
            indexers=indexers,
            transform_kwargs={},
        )
        super().__init__(**kwargs)

    def sel(self, ds):
        """Apply ``ds.sel`` with the configured indexers and return the result."""
        indexers = self._params["indexers"]
        if isinstance(indexers, str):
            # NOTE(security): eval executes arbitrary Python, so the indexers
            # string must come from a trusted catalog.  It is kept (rather
            # than ast.literal_eval) because indexers may legitimately use
            # non-literal expressions such as ``slice(...)``.
            indexers = eval(indexers)
        # A mapping supplied directly (not as a string) is used as-is;
        # calling eval on a dict would raise TypeError.
        return ds.sel(indexers)
0 commit comments