Skip to content

Commit d240a82

Browse files
authored
Improvements for xarray provider (geopython#1800)
* Manage non-cf-compliant time dimension * Manage datasets without a time dimension * Allow reversed slices also for axes * Convert also metadata to float64 for json output * Use named temporary file to enable netcdf4 engine * Make float64 conversion faster * Add netcdf output to xarray provider * Flake8 fixes * Fix bug when no time axis in data * Use new xarray interface * Add test for zarr dataset without time dimension * Avoid errors if missing long_name * Manage zarr and netcdf output in the same way * Revert "Manage zarr and netcdf output in the same way" This reverts commit 0b09281. * Revert "Add netcdf output to xarray provider" This reverts commit 9f72bf7.
1 parent 474cb60 commit d240a82

File tree

2 files changed

+129
-55
lines changed

2 files changed

+129
-55
lines changed

pygeoapi/provider/xarray_.py

Lines changed: 103 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,19 @@ def __init__(self, provider_def):
8585
else:
8686
data_to_open = self.data
8787

88-
self._data = open_func(data_to_open)
88+
try:
89+
self._data = open_func(data_to_open)
90+
except ValueError as err:
91+
# Manage non-cf-compliant time dimensions
92+
if 'time' in str(err):
93+
self._data = open_func(self.data, decode_times=False)
94+
else:
95+
raise err
96+
8997
self.storage_crs = self._parse_storage_crs(provider_def)
9098
self._coverage_properties = self._get_coverage_properties()
9199

92-
self.axes = [self._coverage_properties['x_axis_label'],
93-
self._coverage_properties['y_axis_label'],
94-
self._coverage_properties['time_axis_label']]
100+
self.axes = self._coverage_properties['axes']
95101

96102
self.get_fields()
97103
except Exception as err:
@@ -101,15 +107,15 @@ def __init__(self, provider_def):
101107
def get_fields(self):
102108
if not self._fields:
103109
for key, value in self._data.variables.items():
104-
if len(value.shape) >= 3:
110+
if key not in self._data.coords:
105111
LOGGER.debug('Adding variable')
106112
dtype = value.dtype
107113
if dtype.name.startswith('float'):
108114
dtype = 'number'
109115

110116
self._fields[key] = {
111117
'type': dtype,
112-
'title': value.attrs['long_name'],
118+
'title': value.attrs.get('long_name'),
113119
'x-ogc-unit': value.attrs.get('units')
114120
}
115121

@@ -142,9 +148,9 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
142148

143149
data = self._data[[*properties]]
144150

145-
if any([self._coverage_properties['x_axis_label'] in subsets,
146-
self._coverage_properties['y_axis_label'] in subsets,
147-
self._coverage_properties['time_axis_label'] in subsets,
151+
if any([self._coverage_properties.get('x_axis_label') in subsets,
152+
self._coverage_properties.get('y_axis_label') in subsets,
153+
self._coverage_properties.get('time_axis_label') in subsets,
148154
datetime_ is not None]):
149155

150156
LOGGER.debug('Creating spatio-temporal subset')
@@ -163,18 +169,36 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
163169
self._coverage_properties['y_axis_label'] in subsets,
164170
len(bbox) > 0]):
165171
msg = 'bbox and subsetting by coordinates are exclusive'
166-
LOGGER.warning(msg)
172+
LOGGER.error(msg)
167173
raise ProviderQueryError(msg)
168174
else:
169-
query_params[self._coverage_properties['x_axis_label']] = \
170-
slice(bbox[0], bbox[2])
171-
query_params[self._coverage_properties['y_axis_label']] = \
172-
slice(bbox[1], bbox[3])
175+
x_axis_label = self._coverage_properties['x_axis_label']
176+
x_coords = data.coords[x_axis_label]
177+
if x_coords.values[0] > x_coords.values[-1]:
178+
LOGGER.debug(
179+
'Reversing slicing of x axis from high to low'
180+
)
181+
query_params[x_axis_label] = slice(bbox[2], bbox[0])
182+
else:
183+
query_params[x_axis_label] = slice(bbox[0], bbox[2])
184+
y_axis_label = self._coverage_properties['y_axis_label']
185+
y_coords = data.coords[y_axis_label]
186+
if y_coords.values[0] > y_coords.values[-1]:
187+
LOGGER.debug(
188+
'Reversing slicing of y axis from high to low'
189+
)
190+
query_params[y_axis_label] = slice(bbox[3], bbox[1])
191+
else:
192+
query_params[y_axis_label] = slice(bbox[1], bbox[3])
173193

174194
LOGGER.debug('bbox_crs is not currently handled')
175195

176196
if datetime_ is not None:
177-
if self._coverage_properties['time_axis_label'] in subsets:
197+
if self._coverage_properties['time_axis_label'] is None:
198+
msg = 'Dataset does not contain a time axis'
199+
LOGGER.error(msg)
200+
raise ProviderQueryError(msg)
201+
elif self._coverage_properties['time_axis_label'] in subsets:
178202
msg = 'datetime and temporal subsetting are exclusive'
179203
LOGGER.error(msg)
180204
raise ProviderQueryError(msg)
@@ -196,32 +220,36 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
196220
LOGGER.warning(err)
197221
raise ProviderQueryError(err)
198222

199-
if (any([data.coords[self.x_field].size == 0,
200-
data.coords[self.y_field].size == 0,
201-
data.coords[self.time_field].size == 0])):
223+
if any(size == 0 for size in data.sizes.values()):
202224
msg = 'No data found'
203225
LOGGER.warning(msg)
204226
raise ProviderNoDataError(msg)
205227

228+
if format_ == 'json':
229+
# json does not support float32
230+
data = _convert_float32_to_float64(data)
231+
206232
out_meta = {
207233
'bbox': [
208234
data.coords[self.x_field].values[0],
209235
data.coords[self.y_field].values[0],
210236
data.coords[self.x_field].values[-1],
211237
data.coords[self.y_field].values[-1]
212238
],
213-
"time": [
214-
_to_datetime_string(data.coords[self.time_field].values[0]),
215-
_to_datetime_string(data.coords[self.time_field].values[-1])
216-
],
217239
"driver": "xarray",
218240
"height": data.sizes[self.y_field],
219241
"width": data.sizes[self.x_field],
220-
"time_steps": data.sizes[self.time_field],
221242
"variables": {var_name: var.attrs
222243
for var_name, var in data.variables.items()}
223244
}
224245

246+
if self.time_field is not None:
247+
out_meta['time'] = [
248+
_to_datetime_string(data.coords[self.time_field].values[0]),
249+
_to_datetime_string(data.coords[self.time_field].values[-1]),
250+
]
251+
out_meta["time_steps"] = data.sizes[self.time_field]
252+
225253
LOGGER.debug('Serializing data in memory')
226254
if format_ == 'json':
227255
LOGGER.debug('Creating output in CoverageJSON')
@@ -230,9 +258,11 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
230258
LOGGER.debug('Returning data in native zarr format')
231259
return _get_zarr_data(data)
232260
else: # return data in native format
233-
with tempfile.TemporaryFile() as fp:
261+
with tempfile.NamedTemporaryFile() as fp:
234262
LOGGER.debug('Returning data in native NetCDF format')
235-
fp.write(data.to_netcdf())
263+
data.to_netcdf(
264+
fp.name
265+
) # we need to pass a string to be able to use the "netcdf4" engine # noqa
236266
fp.seek(0)
237267
return fp.read()
238268

@@ -249,7 +279,6 @@ def gen_covjson(self, metadata, data, fields):
249279

250280
LOGGER.debug('Creating CoverageJSON domain')
251281
minx, miny, maxx, maxy = metadata['bbox']
252-
mint, maxt = metadata['time']
253282

254283
selected_fields = {
255284
key: value for key, value in self.fields.items()
@@ -285,11 +314,6 @@ def gen_covjson(self, metadata, data, fields):
285314
'start': maxy,
286315
'stop': miny,
287316
'num': metadata['height']
288-
},
289-
self.time_field: {
290-
'start': mint,
291-
'stop': maxt,
292-
'num': metadata['time_steps']
293317
}
294318
},
295319
'referencing': [{
@@ -304,6 +328,14 @@ def gen_covjson(self, metadata, data, fields):
304328
'ranges': {}
305329
}
306330

331+
if self.time_field is not None:
332+
mint, maxt = metadata['time']
333+
cj['domain']['axes'][self.time_field] = {
334+
'start': mint,
335+
'stop': maxt,
336+
'num': metadata['time_steps'],
337+
}
338+
307339
for key, value in selected_fields.items():
308340
parameter = {
309341
'type': 'Parameter',
@@ -322,21 +354,25 @@ def gen_covjson(self, metadata, data, fields):
322354
cj['parameters'][key] = parameter
323355

324356
data = data.fillna(None)
325-
data = _convert_float32_to_float64(data)
326357

327358
try:
328359
for key, value in selected_fields.items():
329360
cj['ranges'][key] = {
330361
'type': 'NdArray',
331362
'dataType': value['type'],
332363
'axisNames': [
333-
'y', 'x', self._coverage_properties['time_axis_label']
364+
'y', 'x'
334365
],
335366
'shape': [metadata['height'],
336-
metadata['width'],
337-
metadata['time_steps']]
367+
metadata['width']]
338368
}
339369
cj['ranges'][key]['values'] = data[key].values.flatten().tolist() # noqa
370+
371+
if self.time_field is not None:
372+
cj['ranges'][key]['axisNames'].append(
373+
self._coverage_properties['time_axis_label']
374+
)
375+
cj['ranges'][key]['shape'].append(metadata['time_steps'])
340376
except IndexError as err:
341377
LOGGER.warning(err)
342378
raise ProviderQueryError('Invalid query parameter')
@@ -382,31 +418,37 @@ def _get_coverage_properties(self):
382418
self._data.coords[self.x_field].values[-1],
383419
self._data.coords[self.y_field].values[-1],
384420
],
385-
'time_range': [
386-
_to_datetime_string(
387-
self._data.coords[self.time_field].values[0]
388-
),
389-
_to_datetime_string(
390-
self._data.coords[self.time_field].values[-1]
391-
)
392-
],
393421
'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
394422
'crs_type': 'GeographicCRS',
395423
'x_axis_label': self.x_field,
396424
'y_axis_label': self.y_field,
397-
'time_axis_label': self.time_field,
398425
'width': self._data.sizes[self.x_field],
399426
'height': self._data.sizes[self.y_field],
400-
'time': self._data.sizes[self.time_field],
401-
'time_duration': self.get_time_coverage_duration(),
402427
'bbox_units': 'degrees',
403-
'resx': np.abs(self._data.coords[self.x_field].values[1]
404-
- self._data.coords[self.x_field].values[0]),
405-
'resy': np.abs(self._data.coords[self.y_field].values[1]
406-
- self._data.coords[self.y_field].values[0]),
407-
'restime': self.get_time_resolution()
428+
'resx': np.abs(
429+
self._data.coords[self.x_field].values[1]
430+
- self._data.coords[self.x_field].values[0]
431+
),
432+
'resy': np.abs(
433+
self._data.coords[self.y_field].values[1]
434+
- self._data.coords[self.y_field].values[0]
435+
),
408436
}
409437

438+
if self.time_field is not None:
439+
properties['time_axis_label'] = self.time_field
440+
properties['time_range'] = [
441+
_to_datetime_string(
442+
self._data.coords[self.time_field].values[0]
443+
),
444+
_to_datetime_string(
445+
self._data.coords[self.time_field].values[-1]
446+
),
447+
]
448+
properties['time'] = self._data.sizes[self.time_field]
449+
properties['time_duration'] = self.get_time_coverage_duration()
450+
properties['restime'] = self.get_time_resolution()
451+
410452
# Update properties based on the xarray's CRS
411453
epsg_code = self.storage_crs.to_epsg()
412454
LOGGER.debug(f'{epsg_code}')
@@ -425,10 +467,12 @@ def _get_coverage_properties(self):
425467

426468
properties['axes'] = [
427469
properties['x_axis_label'],
428-
properties['y_axis_label'],
429-
properties['time_axis_label']
470+
properties['y_axis_label']
430471
]
431472

473+
if self.time_field is not None:
474+
properties['axes'].append(properties['time_axis_label'])
475+
432476
return properties
433477

434478
@staticmethod
@@ -455,7 +499,8 @@ def get_time_resolution(self):
455499
:returns: time resolution string
456500
"""
457501

458-
if self._data[self.time_field].size > 1:
502+
if self.time_field is not None \
503+
and self._data[self.time_field].size > 1:
459504
time_diff = (self._data[self.time_field][1] -
460505
self._data[self.time_field][0])
461506

@@ -472,6 +517,9 @@ def get_time_coverage_duration(self):
472517
:returns: time coverage duration string
473518
"""
474519

520+
if self.time_field is None:
521+
return None
522+
475523
dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
476524
ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)
477525

@@ -634,7 +682,7 @@ def _convert_float32_to_float64(data):
634682
for var_name in data.variables:
635683
if data[var_name].dtype == 'float32':
636684
og_attrs = data[var_name].attrs
637-
data[var_name] = data[var_name].astype('float64')
685+
data[var_name] = data[var_name].astype('float64', copy=False)
638686
data[var_name].attrs = og_attrs
639687

640688
return data

tests/test_xarray_zarr_provider.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from numpy import float64, int64
3131

3232
import pytest
33+
import xarray as xr
3334

3435
from pygeoapi.provider.xarray_ import XarrayProvider
3536
from pygeoapi.util import json_serial
@@ -53,6 +54,20 @@ def config():
5354
}
5455

5556

57+
@pytest.fixture()
58+
def config_no_time(tmp_path):
59+
ds = xr.open_zarr(path)
60+
ds = ds.sel(time=ds.time[0])
61+
ds = ds.drop_vars('time')
62+
ds.to_zarr(tmp_path / 'no_time.zarr')
63+
return {
64+
'name': 'zarr',
65+
'type': 'coverage',
66+
'data': str(tmp_path / 'no_time.zarr'),
67+
'format': {'name': 'zarr', 'mimetype': 'application/zip'},
68+
}
69+
70+
5671
def test_provider(config):
5772
p = XarrayProvider(config)
5873

@@ -85,3 +100,14 @@ def test_numpy_json_serial():
85100

86101
d = float64(500.00000005)
87102
assert json_serial(d) == 500.00000005
103+
104+
105+
def test_no_time(config_no_time):
106+
p = XarrayProvider(config_no_time)
107+
108+
assert len(p.fields) == 4
109+
assert p.axes == ['lon', 'lat']
110+
111+
coverage = p.query(format='json')
112+
113+
assert sorted(coverage['domain']['axes'].keys()) == ['x', 'y']

0 commit comments

Comments
 (0)