|
8 | 8 |
|
9 | 9 | import matplotlib.pyplot as plt |
10 | 10 | import numpy as np |
| 11 | +import pandas |
11 | 12 | import pytest |
12 | 13 | import xarray as xr |
13 | 14 | from shapely.geometry import LineString, Point, Polygon, box |
@@ -146,6 +147,96 @@ def test_get_depth_name_missing() -> None: |
146 | 147 | dataset.ems.get_depth_name() |
147 | 148 |
|
148 | 149 |
|
@pytest.mark.parametrize('include_time', [True, False])
@pytest.mark.parametrize('include_depth', [True, False])
def test_select_variables(
    include_time: bool,
    include_depth: bool,
) -> None:
    """
    Check that selecting a subset of variables keeps the coordinates.

    A small dataset is built on a 6x5 (y, x) grid. Time and depth
    coordinates, and the variables defined on them, are added or left out
    depending on the test parameters. For each candidate subset of variable
    names, the result of ``convention.select_variables()`` must contain
    exactly the requested variables plus every coordinate variable of the
    source dataset, and each of those variables must compare equal to the
    corresponding variable in the source dataset.

    Parameters
    ----------
    include_time : bool
        Whether the dataset has a ``time`` coordinate and an ``eta`` variable.
    include_depth : bool
        Whether the dataset has a ``depth`` coordinate
        and an ``octarine`` variable.
        When both flags are set a ``temperature`` variable
        using both time and depth is also added.
    """
    # Generate a dataset with some random data.
    # Time and depth dimensions are included or omitted based on the test arguments.
    # The generator is seeded so that any failure can be reproduced exactly.
    generator = np.random.default_rng(seed=0)
    expected_coords = {'y', 'x'}

    x_size, y_size = 5, 6
    dataset = xr.Dataset({
        'x': (['x'], np.arange(x_size), {'units': 'degrees_east'}),
        'y': (['y'], np.arange(y_size), {'units': 'degrees_north'}),
        'colour': (['y', 'x'], np.arange(x_size * y_size).reshape((y_size, x_size)), {}),
        'flavour': (['y', 'x'], np.arange(x_size * y_size).reshape((y_size, x_size)), {}),
    })

    if include_time:
        expected_coords.add('time')
        dataset = dataset.assign_coords({
            'time': xr.DataArray(
                dims=['time'],
                data=pandas.date_range('2023-08-01', '2023-08-24'),
            ),
        })
        dataset['time'].encoding['units'] = 'days since 1990-01-01 00:00:00 +10:00'
        time_size = dataset['time'].size
        dataset = dataset.assign({
            'eta': xr.DataArray(
                dims=['time', 'y', 'x'],
                data=generator.uniform(-1.0, 1.0, (time_size, y_size, x_size)),
                attrs={'standard_name': 'sea_surface_height'},
            ),
        })

    if include_depth:
        expected_coords.add('depth')
        depth_size = 4
        dataset = dataset.assign_coords({
            'depth': xr.DataArray(
                dims=['depth'],
                data=np.linspace(-10, 0, depth_size),
                attrs={'standard_name': 'depth', 'positive': 'up'},
            ),
        })
        dataset = dataset.assign({
            'octarine': xr.DataArray(
                dims=['depth', 'y', 'x'],
                data=(
                    generator.uniform(0.0, 1.0, (depth_size, y_size, x_size))
                    # Scale each depth layer by a factor running from 100 down to 0.
                    * np.linspace(100, 0, depth_size)[:, np.newaxis, np.newaxis]
                ),
                attrs={'standard_name': 'octarine_concentration'},
            ),
        })

    if include_depth and include_time:
        dataset = dataset.assign({
            'temperature': xr.DataArray(
                dims=['time', 'depth', 'y', 'x'],
                data=(
                    generator.uniform(0, 3, (time_size, depth_size, y_size, x_size))
                    # Add a per-layer offset running from 2 to 20 along the depth axis.
                    + np.linspace(2, 20, depth_size)[np.newaxis, :, np.newaxis, np.newaxis]
                ),
            ),
        })

    # Test various variable subset selections
    # It should be possible to select all sorts of subsets.
    # This should preserve coordinate information,
    # even if no variables using those coordinates are included in the subset.
    variable_choices = [{'colour'}]
    if include_time:
        variable_choices.append({'colour', 'eta'})
    if include_depth:
        variable_choices.append({'colour', 'octarine'})
    if include_depth and include_time:
        variable_choices.append({'colour', 'eta', 'octarine', 'temperature'})

    convention: Convention = SimpleConvention(dataset)
    for variables in variable_choices:
        subset = convention.select_variables(variables)
        expected_variables = variables | expected_coords
        assert set(subset.variables) == expected_variables
        for name in subset.variables:
            xr.testing.assert_equal(dataset[name], subset[name])
| 239 | + |
149 | 240 | def test_mask(): |
150 | 241 | dataset = xr.Dataset({ |
151 | 242 | 'values': (['z', 'y', 'x'], np.random.standard_normal((5, 10, 20))), |
|
0 commit comments