NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Line breaks follow the hunk headers (every hunk's line counts match its
header), but the indentation of context lines is inferred - confirm it
against the target files, or apply with whitespace differences ignored.
The test hunk was edited during review (assertIsNone, stanza comments), so
the post-image hash on the "index" line of cf/test/test_Data.py is stale.

diff --git a/Changelog.rst b/Changelog.rst
index 67108990b5..f0d5f95692 100644
--- a/Changelog.rst
+++ b/Changelog.rst
@@ -13,6 +13,9 @@ version 3.17.0
   ``key`` (https://github.com/NCAS-CMS/cf-python/issues/802)
 * New keyword parameter to `cf.histogram`: ``density``
   (https://github.com/NCAS-CMS/cf-python/issues/794)
+* Fix bug that caused `Data._axes` to be incorrect after a call to
+  `cf.Field.collapse`
+  (https://github.com/NCAS-CMS/cf-python/issues/857)
 * Changed dependency: ``Python>=3.9.0``
 * Changed dependency: ``numpy>=2.0.0``
 * Changed dependency: ``cfdm>=1.12.0.0, <1.12.1.0``
diff --git a/cf/data/utils.py b/cf/data/utils.py
index c1b1a63920..63b2a40e88 100644
--- a/cf/data/utils.py
+++ b/cf/data/utils.py
@@ -402,8 +402,14 @@ def collapse(
             weights, axis)``.
 
     """
+    original_size = d.size
+    if axis is None:
+        axis = range(d.ndim)
+    else:
+        axis = d._parse_axes(axis)
+
     kwargs = {
-        "axis": axis,
+        "axis": tuple(axis),
         "keepdims": keepdims,
         "split_every": split_every,
         "mtol": mtol,
@@ -424,6 +430,14 @@ def collapse(
     dx = func(dx, **kwargs)
     d._set_dask(dx)
 
+    if not keepdims:
+        # Remove collapsed axis names
+        d._axes = [a for i, a in enumerate(d._axes) if i not in axis]
+
+    if d.size != original_size:
+        # Remove the out-dated HDF5 chunking strategy
+        d.nc_clear_hdf5_chunksizes()
+
     return d, weights
 
 
diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py
index b25342afce..6692737993 100644
--- a/cf/test/test_Data.py
+++ b/cf/test/test_Data.py
@@ -4657,6 +4657,38 @@ def test_Data_is_masked(self):
         self.assertTrue(d[0].is_masked)
         self.assertFalse(d[1].is_masked)
 
+    def test_Data_collapse_axes_hdf_chunks(self):
+        """Test that _axes and hdf_chunks are updated after a collapse."""
+        d = cf.Data([[1, 2, 3, 4]])
+        chunks = d.shape
+        d.nc_set_hdf5_chunksizes(chunks)
+        # The collapse returns a new object, so 'd' must be unchanged
+        e = d.mean(axes=1)
+        self.assertEqual(d._axes, ("dim0", "dim1"))
+        self.assertEqual(d.nc_hdf5_chunksizes(), chunks)
+
+        # Size-changing collapse, axis kept: same axes, chunk sizes cleared
+        e = d.mean(axes=1)
+        self.assertNotEqual(e.size, d.size)
+        self.assertEqual(e._axes, d._axes)
+        self.assertIsNone(e.nc_hdf5_chunksizes())
+
+        # Size-changing collapse, axis squeezed out: its axis name is dropped
+        e = d.mean(axes=1, squeeze=True)
+        self.assertEqual(e._axes, d._axes[:1])
+        self.assertIsNone(e.nc_hdf5_chunksizes())
+
+        # Size-preserving collapse, axis kept: nothing changes
+        e = d.mean(axes=0)
+        self.assertEqual(e.size, d.size)
+        self.assertEqual(e._axes, d._axes)
+        self.assertEqual(e.nc_hdf5_chunksizes(), chunks)
+
+        # Size-preserving collapse, axis squeezed out: chunk sizes retained
+        e = d.mean(axes=0, squeeze=True)
+        self.assertEqual(e._axes, d._axes[1:])
+        self.assertEqual(e.nc_hdf5_chunksizes(), chunks)
+
 
 if __name__ == "__main__":
     print("Run date:", datetime.datetime.now())