Skip to content

Commit 9c54b9b

Browse files
raspstephan and WeatherBenchX authors
authored and committed
Add option to add raw values as coordinates in data loader base class (to be used for binning by raw values).
In sparse_parquet and xarray_loaders, make all base class arguments kwargs.

PiperOrigin-RevId: 869579359
1 parent 83904e5 commit 9c54b9b

File tree

3 files changed

+25
-42
lines changed

3 files changed

+25
-42
lines changed

weatherbenchX/data_loaders/base.py

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -77,6 +77,7 @@ def __init__(
7777
[Mapping[Hashable, xr.DataArray]], Mapping[Hashable, xr.DataArray]
7878
]
7979
] = None,
80+
add_values_to_coords: bool = False,
8081
):
8182
"""Shared initialization for data loaders.
8283
@@ -89,11 +90,15 @@ def __init__(
8990
False.
9091
process_chunk_fn: optional function to be applied to each chunk after
9192
loading but before interpolation, computing, and adding nan mask.
93+
add_values_to_coords: If True, add returned values to coordinates. These
94+
will propagate into the statistics, and can therefore be used for
95+
binning. Default: False.
9296
"""
9397
self._interpolation = interpolation
9498
self._compute = compute
9599
self._add_nan_mask = add_nan_mask
96100
self._process_chunk_fn = process_chunk_fn
101+
self._add_values_to_coords = add_values_to_coords
97102

98103
@abc.abstractmethod
99104
def _load_chunk_from_source(
@@ -149,4 +154,10 @@ def _compute_and_keep_dtype(x: xr.DataArray) -> xr.DataArray:
149154

150155
if self._add_nan_mask:
151156
chunk = add_nan_mask_to_data(chunk)
157+
158+
if self._add_values_to_coords:
159+
chunk = xarray_tree.map_structure(
160+
lambda da: da.assign_coords(values_as_coord=da), chunk
161+
)
162+
152163
return chunk

weatherbenchX/data_loaders/sparse_parquet.py

Lines changed: 7 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414
"""Data loaders for tabular data stored in Parquet format."""
1515

16+
from collections.abc import Hashable
1617
import functools
1718
import os
18-
from typing import Callable, Hashable, Mapping, Optional, Sequence, Union
19+
from typing import Callable, Mapping, Optional, Sequence, Union
1920
import numpy as np
2021
import pandas as pd
2122
import pyarrow
22-
from weatherbenchX import interpolations
2323
from weatherbenchX.data_loaders import base
2424
import xarray as xr
2525

@@ -91,7 +91,6 @@ def __init__(
9191
coordinate_variables: Sequence[str] = (),
9292
split_variables: bool = False,
9393
dropna: bool = False,
94-
add_nan_mask: bool = False,
9594
tolerance: Optional[
9695
np.timedelta64 | tuple[np.timedelta64, np.timedelta64]
9796
] = None,
@@ -102,8 +101,7 @@ def __init__(
102101
observation_dim: Optional[str] = None,
103102
file_tolerance: np.timedelta64 = np.timedelta64(1, 'h'),
104103
preprocessing_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
105-
interpolation: Optional[interpolations.Interpolation] = None,
106-
process_chunk_fn: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
104+
**kwargs,
107105
):
108106
"""Init.
109107
@@ -123,10 +121,6 @@ def __init__(
123121
dropna: Whether to drop missing values. If split_variables is True, values
124122
will be dropped for each variable separately. Otherwise, only indices
125123
where all variables are non-NaN will be returned.
126-
add_nan_mask: Adds a boolean coordinate named 'mask' to each variable
127-
(variables will be split into DataArrays if they aren't already), with
128-
False indicating NaN values. To be used for masked aggregation. Default:
129-
False.
130124
tolerance: (Optional) Tolerance around the given valid time. If tolerance
131125
is a single timedelta, data within valid_time +/- tolerance will be
132126
returned. If tolerance is a 2-tuple of timedeltas, data within
@@ -153,16 +147,12 @@ def __init__(
153147
1h
154148
preprocessing_fn: (Optional) Function to apply to the dataframe after
155149
reading.
156-
interpolation: (Optional) Interpolation to be applied to the data.
157-
process_chunk_fn: (Optional) Function to apply to the chunk of data after
158-
loading.
150+
**kwargs: Additional keyword arguments passed to the base DataLoader.
159151
"""
160152

161153
super().__init__(
162-
interpolation=interpolation,
163154
compute=False, # Data is already loaded.
164-
add_nan_mask=add_nan_mask,
165-
process_chunk_fn=process_chunk_fn,
155+
**kwargs
166156
)
167157
self._path = path
168158
if partitioned_by not in ['hour', 'day', 'month']:
@@ -479,7 +469,6 @@ def __init__(
479469
time_dim: str,
480470
split_variables: bool = False,
481471
dropna: bool = False,
482-
add_nan_mask: bool = False,
483472
tolerance: Optional[np.timedelta64] = None,
484473
partitioned_by: str = 'month',
485474
rename_variables: Optional[Mapping[str, str]] = None,
@@ -488,8 +477,7 @@ def __init__(
488477
pick_closest_duplicate_by: Optional[str] = None,
489478
file_tolerance: np.timedelta64 = np.timedelta64(1, 'h'),
490479
preprocessing_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
491-
interpolation: Optional[interpolations.Interpolation] = None,
492-
process_chunk_fn: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
480+
**kwargs,
493481
):
494482
def metar_preprocessing_fn(
495483
df: pd.DataFrame,
@@ -521,7 +509,6 @@ def metar_preprocessing_fn(
521509
observation_dim='stationName',
522510
split_variables=split_variables,
523511
dropna=dropna,
524-
add_nan_mask=add_nan_mask,
525512
tolerance=tolerance,
526513
partitioned_by=partitioned_by,
527514
rename_variables=METAR_TO_ERA5_NAMES,
@@ -532,6 +519,5 @@ def metar_preprocessing_fn(
532519
preprocessing_fn=functools.partial(
533520
metar_preprocessing_fn, preprocessing_fn=preprocessing_fn
534521
),
535-
interpolation=interpolation,
536-
process_chunk_fn=process_chunk_fn,
522+
**kwargs,
537523
)

weatherbenchX/data_loaders/xarray_loaders.py

Lines changed: 7 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414
"""Data loaders for reading gridded Zarr files."""
1515

16-
from typing import Any, Callable, Hashable, Iterable, Mapping, Optional, Union
16+
from collections.abc import Hashable
17+
from typing import Any, Callable, Iterable, Mapping, Optional, Union
18+
19+
from absl import logging
1720
import numpy as np
18-
from weatherbenchX import interpolations
1921
from weatherbenchX.data_loaders import base
2022
import xarray as xr
21-
from absl import logging
2223

2324

2425
def _rename_dataset(
@@ -63,11 +64,8 @@ def __init__(
6364
rename_dimensions: Optional[Union[Mapping[str, str], str]] = 'ecmwf',
6465
automatically_convert_lat_lon_to_latitude_longitude: bool = True,
6566
rename_variables: Optional[Mapping[str, str]] = None,
66-
interpolation: Optional[interpolations.Interpolation] = None,
67-
compute: bool = True,
68-
add_nan_mask: bool = False,
6967
preprocessing_fn: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
70-
process_chunk_fn: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
68+
**kwargs,
7169
):
7270
"""Init.
7371
@@ -91,16 +89,9 @@ def __init__(
9189
automatically convert 'lat' and 'lon' dimensions to 'latitude' and
9290
'longitude'. Default: True.
9391
rename_variables: (Optional) Dictionary of variables to rename.
94-
interpolation: (Optional) Interpolation instance.
95-
compute: Whether to load data into memory. Default: True.
96-
add_nan_mask: Adds a boolean coordinate named 'mask' to each variable
97-
(variables will be split into DataArrays if they aren't already), with
98-
False indicating NaN values. To be used for masked aggregation. Default:
99-
False.
10092
preprocessing_fn: (Optional) A function that is applied to the dataset
10193
right after it is opened.
102-
process_chunk_fn: (Optional) A function that is applied to each chunk
103-
after loading, interpolation and compute, but before computing a mask.
94+
**kwargs: Keyword arguments to pass to base.DataLoader.
10495
"""
10596
if path is not None and ds is not None:
10697
raise ValueError('Only one of path or ds can be specified, not both.')
@@ -120,12 +111,7 @@ def __init__(
120111
self._preprocessing_fn = preprocessing_fn
121112

122113
self._preprocessed = False
123-
super().__init__(
124-
interpolation=interpolation,
125-
compute=compute,
126-
add_nan_mask=add_nan_mask,
127-
process_chunk_fn=process_chunk_fn,
128-
)
114+
super().__init__(**kwargs)
129115

130116
def maybe_prepare_dataset(self):
131117
"""Prepares the dataset (reads and preprocesses it, if not already done)."""

0 commit comments

Comments (0)