Skip to content

Commit a66687f

Browse files
authored
Merge pull request #830 from davidhassell/fix-cyclic-subspace
Fix bug where `cf.Field.subspace` doesn't always correctly handle global cyclic subspaces
2 parents fa5e6d8 + 228abd4 commit a66687f

File tree

8 files changed

+343
-101
lines changed

8 files changed

+343
-101
lines changed

Changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ version NEXTVERSION
3333
* Fix bug where `cf.normalize_slice` doesn't correctly
3434
handle certain cyclic slices
3535
(https://github.com/NCAS-CMS/cf-python/issues/774)
36+
* Fix bug where `cf.Field.subspace` doesn't always correctly handle
37+
global or near-global cyclic subspaces
38+
(https://github.com/NCAS-CMS/cf-python/issues/828)
3639
* New dependency: ``h5netcdf>=1.3.0``
3740
* New dependency: ``h5py>=3.10.0``
3841
* New dependency: ``s3fs>=2024.2.0``
3942
* Changed dependency: ``1.11.2.0<=cfdm<1.11.3.0``
4043
* Changed dependency: ``cfunits>=3.3.7``
4144

42-
4345
----
4446

4547
version 3.16.2

cf/cfimplementation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
TiePointIndex,
2727
)
2828
from .data import Data
29-
3029
from .data.array import (
3130
BoundsFromNodesArray,
3231
CellConnectivityArray,

cf/data/data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from ..units import Units
4444
from .collapse import Collapse
4545
from .creation import generate_axis_identifiers, to_dask
46-
4746
from .dask_utils import (
4847
_da_ma_allclose,
4948
cf_asanyarray,

cf/dimensioncoordinate.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
_inplace_enabled,
99
_inplace_enabled_define_and_cleanup,
1010
)
11-
from .functions import _DEPRECATION_ERROR_ATTRIBUTE, _DEPRECATION_ERROR_KWARGS
11+
from .functions import (
12+
_DEPRECATION_ERROR_ATTRIBUTE,
13+
_DEPRECATION_ERROR_KWARGS,
14+
bounds_combination_mode,
15+
)
1216
from .timeduration import TimeDuration
1317
from .units import Units
1418

@@ -246,6 +250,154 @@ def increasing(self):
246250
"""
247251
return self.direction()
248252

253+
@_inplace_enabled(default=False)
254+
def anchor(self, value, cell=False, parameters=None, inplace=False):
255+
"""Anchor the coordinate values.
256+
257+
By default, the coordinate values are transformed so that the
258+
first coordinate is the closest to *value* from above (below)
259+
for increasing (decreasing) coordinates.
260+
261+
If the *cell* parameter is True, then the coordinate values
262+
are transformed so that the first cell either contains
263+
*value*; or is the closest to cell to *value* from above
264+
(below) for increasing (decreasing) coordinates.
265+
266+
.. versionadded:: NEXTVERSION
267+
268+
.. seealso:: `period`, `roll`
269+
270+
:Parameters:
271+
272+
value: scalar array_like
273+
Anchor the coordinate values for the selected cyclic
274+
axis to the *value*. May be any numeric scalar object
275+
that can be converted to a `Data` object (which
276+
includes `numpy` and `Data` objects). If *value* has
277+
units then they must be compatible with those of the
278+
coordinates, otherwise it is assumed to have the same
279+
units as the coordinates.
280+
281+
The coordinate values are transformed so the first
282+
corodinate is the closest to *value* from above (for
283+
increasing coordinates), or the closest to *value* from
284+
above (for decreasing coordinates)
285+
286+
* Increasing coordinates with positive period, P,
287+
are transformed so that *value* lies in the
288+
half-open range (L-P, F], where F and L are the
289+
transformed first and last coordinate values,
290+
respectively.
291+
292+
..
293+
294+
* Decreasing coordinates with positive period, P,
295+
are transformed so that *value* lies in the
296+
half-open range (L+P, F], where F and L are the
297+
transformed first and last coordinate values,
298+
respectively.
299+
300+
*Parameter example:*
301+
If the original coordinates are ``0, 5, ..., 355``
302+
(evenly spaced) and the period is ``360`` then
303+
``value=0`` implies transformed coordinates of ``0,
304+
5, ..., 355``; ``value=-12`` implies transformed
305+
coordinates of ``-10, -5, ..., 345``; ``value=380``
306+
implies transformed coordinates of ``380, 385, ...,
307+
735``.
308+
309+
*Parameter example:*
310+
If the original coordinates are ``355, 350, ..., 0``
311+
(evenly spaced) and the period is ``360`` then
312+
``value=355`` implies transformed coordinates of
313+
``355, 350, ..., 0``; ``value=0`` implies
314+
transformed coordinates of ``0, -5, ..., -355``;
315+
``value=392`` implies transformed coordinates of
316+
``390, 385, ..., 35``.
317+
318+
cell: `bool`, optional
319+
If True, then the coordinate values are transformed so
320+
that the first cell either contains *value*, or is the
321+
closest to cell to *value* from above (below) for
322+
increasing (decreasing) coordinates.
323+
324+
If False (the default) then the coordinate values are
325+
transformed so that the first coordinate is the closest
326+
to *value* from above (below) for increasing
327+
(decreasing) coordinates.
328+
329+
parameters: `dict`, optional
330+
If a `dict` is provided then it will be updated
331+
in-place with parameters which describe thethe
332+
anchoring process.
333+
334+
{{inplace: `bool`, optional}}
335+
336+
:Returns:
337+
338+
`{{class}}` or `None`
339+
The anchored dimension coordinates, or `None` if the
340+
operation was in-place.
341+
342+
"""
343+
d = _inplace_enabled_define_and_cleanup(self)
344+
345+
period = d.period()
346+
if period is None:
347+
raise ValueError(f"Cyclic {d!r} has no period")
348+
349+
value = d._Data.asdata(value)
350+
if not value.Units:
351+
value = value.override_units(d.Units)
352+
elif not value.Units.equivalent(d.Units):
353+
raise ValueError(
354+
f"Anchor value has incompatible units: {value.Units!r}"
355+
)
356+
357+
if cell:
358+
c = d.upper_bounds.persist()
359+
else:
360+
d.persist(inplace=True)
361+
c = d.get_data(_fill_value=False)
362+
363+
if d.increasing:
364+
# Adjust value so it's in the range [c[0], c[0]+period)
365+
n = ((c[0] - value) / period).ceil()
366+
value1 = value + n * period
367+
shift = c.size - np.argmax((c - value1 >= 0).array)
368+
d.roll(0, shift, inplace=True)
369+
if cell:
370+
d0 = d[0].upper_bounds
371+
else:
372+
d0 = d.get_data(_fill_value=False)[0]
373+
374+
n = ((value - d0) / period).ceil()
375+
else:
376+
# Adjust value so it's in the range (c[0]-period, c[0]]
377+
n = ((c[0] - value) / period).floor()
378+
value1 = value + n * period
379+
shift = c.size - np.argmax((value1 - c >= 0).array)
380+
d.roll(0, shift, inplace=True)
381+
if cell:
382+
d0 = d[0].upper_bounds
383+
else:
384+
d0 = d.get_data(_fill_value=False)[0]
385+
386+
n = ((value - d0) / period).floor()
387+
388+
n.persist(inplace=True)
389+
if n:
390+
nperiod = n * period
391+
with bounds_combination_mode("OR"):
392+
d += nperiod
393+
else:
394+
nperiod = 0
395+
396+
if parameters is not None:
397+
parameters.update({"shift": shift, "nperiod": nperiod})
398+
399+
return d
400+
249401
def direction(self):
250402
"""Return True if the dimension coordinate values are
251403
increasing, otherwise return False.

cf/mixin/fielddomain.py

Lines changed: 44 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
bounds_combination_mode,
1919
normalize_slice,
2020
)
21-
from ..query import Query
21+
from ..query import Query, wi
2222
from ..units import Units
2323

2424
logger = logging.getLogger(__name__)
@@ -440,50 +440,41 @@ def _indices(self, config, data_axes, ancillary_mask, kwargs):
440440
if debug:
441441
logger.debug(" 1-d CASE 2:") # pragma: no cover
442442

443+
size = item.size
443444
if item.increasing:
444-
anchor0 = value.value[0]
445-
anchor1 = value.value[1]
445+
anchor = value.value[0]
446446
else:
447-
anchor0 = value.value[1]
448-
anchor1 = value.value[0]
449-
450-
a = self.anchor(axis, anchor0, dry_run=True)["roll"]
451-
b = self.flip(axis).anchor(axis, anchor1, dry_run=True)[
452-
"roll"
453-
]
454-
455-
size = item.size
456-
if abs(anchor1 - anchor0) >= item.period():
457-
if value.operator == "wo":
458-
set_start_stop = 0
459-
else:
460-
set_start_stop = -a
461-
462-
start = set_start_stop
463-
stop = set_start_stop
464-
elif a + b == size:
465-
b = self.anchor(axis, anchor1, dry_run=True)["roll"]
466-
if (b == a and value.operator == "wo") or not (
467-
b == a or value.operator == "wo"
468-
):
469-
set_start_stop = -a
470-
else:
471-
set_start_stop = 0
447+
anchor = value.value[1]
448+
449+
item = item.persist()
450+
parameters = {}
451+
item = item.anchor(anchor, parameters=parameters)
452+
n = np.roll(np.arange(size), parameters["shift"], 0)
453+
if value.operator == "wi":
454+
n = n[item == value]
455+
if not n.size:
456+
raise ValueError(
457+
f"No indices found from: {identity}={value!r}"
458+
)
472459

473-
start = set_start_stop
474-
stop = set_start_stop
460+
start = n[0]
461+
stop = n[-1] + 1
475462
else:
476-
if value.operator == "wo":
477-
start = b - size
478-
stop = -a + size
479-
else:
480-
start = -a
481-
stop = b - size
463+
# "wo" operator
464+
n = n[item == wi(*value.value)]
465+
if n.size == size:
466+
raise ValueError(
467+
f"No indices found from: {identity}={value!r}"
468+
)
482469

483-
if start == stop == 0:
484-
raise ValueError(
485-
f"No indices found from: {identity}={value!r}"
486-
)
470+
if n.size:
471+
start = n[-1] + 1
472+
stop = start - n.size
473+
else:
474+
start = size - parameters["shift"]
475+
stop = start + size
476+
if stop > size:
477+
stop -= size
487478

488479
index = slice(start, stop, 1)
489480

@@ -1287,77 +1278,34 @@ def anchor(
12871278
self, "anchor", kwargs
12881279
) # pragma: no cover
12891280

1290-
da_key, axis = self.domain_axis(axis, item=True)
1281+
axis = self.domain_axis(axis, key=True)
12911282

12921283
if dry_run:
12931284
f = self
12941285
else:
12951286
f = _inplace_enabled_define_and_cleanup(self)
12961287

1297-
dim = f.dimension_coordinate(filter_by_axis=(da_key,), default=None)
1288+
dim = f.dimension_coordinate(filter_by_axis=(axis,), default=None)
12981289
if dim is None:
12991290
raise ValueError(
13001291
"Can't shift non-cyclic "
1301-
f"{f.constructs.domain_axis_identity(da_key)!r} axis"
1292+
f"{f.constructs.domain_axis_identity(axis)!r} axis"
13021293
)
13031294

1304-
period = dim.period()
1305-
if period is None:
1306-
raise ValueError(f"Cyclic {dim.identity()!r} axis has no period")
1307-
1308-
value = f._Data.asdata(value)
1309-
if not value.Units:
1310-
value = value.override_units(dim.Units)
1311-
elif not value.Units.equivalent(dim.Units):
1312-
raise ValueError(
1313-
f"Anchor value has incompatible units: {value.Units!r}"
1314-
)
1315-
1316-
axis_size = axis.get_size()
1317-
1318-
if axis_size <= 1:
1319-
# Don't need to roll a size one axis
1320-
if dry_run:
1321-
return {"axis": da_key, "roll": 0, "nperiod": 0}
1322-
1323-
return f
1324-
1325-
c = dim.get_data(_fill_value=False)
1326-
1327-
if dim.increasing:
1328-
# Adjust value so it's in the range [c[0], c[0]+period)
1329-
n = ((c[0] - value) / period).ceil()
1330-
value1 = value + n * period
1331-
1332-
shift = axis_size - np.argmax((c - value1 >= 0).array)
1333-
if not dry_run:
1334-
f.roll(da_key, shift, inplace=True)
1335-
1336-
# Re-get dim
1337-
dim = f.dimension_coordinate(filter_by_axis=(da_key,))
1338-
# TODO CHECK n for dry run or not
1339-
n = ((value - dim.data[0]) / period).ceil()
1340-
else:
1341-
# Adjust value so it's in the range (c[0]-period, c[0]]
1342-
n = ((c[0] - value) / period).floor()
1343-
value1 = value + n * period
1344-
1345-
shift = axis_size - np.argmax((value1 - c >= 0).array)
1346-
1347-
if not dry_run:
1348-
f.roll(da_key, shift, inplace=True)
1349-
1350-
# Re-get dim
1351-
dim = f.dimension_coordinate(filter_by_axis=(da_key,))
1352-
# TODO CHECK n for dry run or not
1353-
n = ((value - dim.data[0]) / period).floor()
1295+
parameters = {"axis": axis}
1296+
dim = dim.anchor(value, parameters=parameters)
13541297

13551298
if dry_run:
1356-
return {"axis": da_key, "roll": shift, "nperiod": n * period}
1299+
return parameters
1300+
1301+
f.roll(axis, parameters["shift"], inplace=True)
13571302

1358-
if n:
1303+
if parameters["nperiod"]:
1304+
# Get the rolled dimension coordinate and adjust the
1305+
# values by the non-zero integer multiple of 'period'
1306+
dim = f.dimension_coordinate(filter_by_axis=(axis,))
13591307
with bounds_combination_mode("OR"):
1360-
dim += n * period
1308+
dim += parameters["nperiod"]
13611309

13621310
return f
13631311

cf/regrid/regrid.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2465,7 +2465,6 @@ def create_esmpy_weights(
24652465
from netCDF4 import Dataset
24662466

24672467
from .. import __version__
2468-
24692468
from ..data.array.locks import netcdf_lock
24702469

24712470
if (

0 commit comments

Comments
 (0)