Skip to content

Commit 601c14e

Browse files
authored
Merge pull request #858 from davidhassell/data-axes
Fix bug that caused `Data._axes` to be incorrect after a call to `cf.Field.collapse`
2 parents 9e2d650 + c6504e0 commit 601c14e

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

Changelog.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ version 3.17.0
1313
``key`` (https://github.com/NCAS-CMS/cf-python/issues/802)
1414
* New keyword parameter to `cf.histogram`: ``density``
1515
(https://github.com/NCAS-CMS/cf-python/issues/794)
16+
* Fix bug that caused `Data._axes` to be incorrect after a call to
17+
`cf.Field.collapse`
18+
(https://github.com/NCAS-CMS/cf-python/issues/857)
1619
* Changed dependency: ``Python>=3.9.0``
1720
* Changed dependency: ``numpy>=2.0.0``
1821
* Changed dependency: ``cfdm>=1.12.0.0, <1.12.1.0``

cf/data/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,14 @@ def collapse(
402402
weights, axis)``.
403403
404404
"""
405+
original_size = d.size
406+
if axis is None:
407+
axis = range(d.ndim)
408+
else:
409+
axis = d._parse_axes(axis)
410+
405411
kwargs = {
406-
"axis": axis,
412+
"axis": tuple(axis),
407413
"keepdims": keepdims,
408414
"split_every": split_every,
409415
"mtol": mtol,
@@ -424,6 +430,14 @@ def collapse(
424430
dx = func(dx, **kwargs)
425431
d._set_dask(dx)
426432

433+
if not keepdims:
434+
# Remove collapsed axis names
435+
d._axes = [a for i, a in enumerate(d._axes) if i not in axis]
436+
437+
if d.size != original_size:
438+
# Remove the out-dated HDF5 chunking strategy
439+
d.nc_clear_hdf5_chunksizes()
440+
427441
return d, weights
428442

429443

cf/test/test_Data.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4657,6 +4657,33 @@ def test_Data_is_masked(self):
46574657
self.assertTrue(d[0].is_masked)
46584658
self.assertFalse(d[1].is_masked)
46594659

4660+
def test_Data_collapse_axes_hdf_chunks(self):
4661+
"""Test that _axes and hdf_chunks are updated after a collapse."""
4662+
d = cf.Data([[1, 2, 3, 4]])
4663+
chunks = d.shape
4664+
d.nc_set_hdf5_chunksizes(chunks)
4665+
e = d.mean(axes=1)
4666+
self.assertEqual(d._axes, ("dim0", "dim1"))
4667+
self.assertEqual(d.nc_hdf5_chunksizes(), chunks)
4668+
4669+
e = d.mean(axes=1)
4670+
self.assertNotEqual(e.size, d.size)
4671+
self.assertEqual(e._axes, d._axes)
4672+
self.assertEqual(e.nc_hdf5_chunksizes(), None)
4673+
4674+
e = d.mean(axes=1, squeeze=True)
4675+
self.assertEqual(e._axes, d._axes[:1])
4676+
self.assertEqual(e.nc_hdf5_chunksizes(), None)
4677+
4678+
e = d.mean(axes=0)
4679+
self.assertEqual(e.size, d.size)
4680+
self.assertEqual(e._axes, d._axes)
4681+
self.assertEqual(e.nc_hdf5_chunksizes(), chunks)
4682+
4683+
e = d.mean(axes=0, squeeze=True)
4684+
self.assertEqual(e._axes, d._axes[1:])
4685+
self.assertEqual(e.nc_hdf5_chunksizes(), chunks)
4686+
46604687

46614688
if __name__ == "__main__":
46624689
print("Run date:", datetime.datetime.now())

0 commit comments

Comments
 (0)