diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 535ea50ea80..91295c45317 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,10 @@ Deprecations Bug Fixes ~~~~~~~~~ +- :py:meth:`Dataset.map` now merges attrs from the function result and the original + using the ``drop_conflicts`` strategy when ``keep_attrs=True``, preserving attrs + set by the function (:issue:`11019`, :pull:`11020`). + By `Maximilian Roos `_. - Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`). By `Julia Signell `_. diff --git a/xarray/computation/weighted.py b/xarray/computation/weighted.py index b311290aabf..d19cc4fea90 100644 --- a/xarray/computation/weighted.py +++ b/xarray/computation/weighted.py @@ -544,13 +544,33 @@ def _implementation(self, func, dim, **kwargs) -> DataArray: dataset = self.obj._to_temp_dataset() dataset = dataset.map(func, dim=dim, **kwargs) - return self.obj._from_temp_dataset(dataset) + result = self.obj._from_temp_dataset(dataset) + + # Clear attrs when keep_attrs is explicitly False + # (weighted operations can propagate attrs from weights through internal computations) + if kwargs.get("keep_attrs") is False: + result.attrs = {} + for var in result.coords.values(): + var.attrs = {} + + return result class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> Dataset: self._check_dim(dim) - return self.obj.map(func, dim=dim, **kwargs) + result = self.obj.map(func, dim=dim, **kwargs) + + # Clear attrs when keep_attrs is explicitly False + # (weighted operations can propagate attrs from weights through internal computations) + if kwargs.get("keep_attrs") is False: + result.attrs = {} + for var in result.data_vars.values(): + var.attrs = {} + for var in result.coords.values(): + var.attrs = {} + + return result def _inject_docstring(cls, cls_name): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e15f1077639..64e4625faaa 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6908,8 +6908,10 @@ def map( DataArray. keep_attrs : bool or None, optional If True, both the dataset's and variables' attributes (`attrs`) will be - copied from the original objects to the new ones. If False, the new dataset - and variables will be returned without copying the attributes. + combined from the original objects and the function results using the + ``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs + are dropped. If False, the new dataset and variables will have only + the attributes set by the function. args : iterable, optional Positional arguments passed on to `func`. **kwargs : Any @@ -6958,16 +6960,19 @@ def map( coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes) if keep_attrs: + # Merge attrs from function result and original, dropping conflicts + from xarray.structure.merge import merge_attrs + for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + v.attrs = merge_attrs( + [v.attrs, self.data_vars[k].attrs], "drop_conflicts" + ) for k, v in coords.items(): if k in self.coords: - v._copy_attrs_from(self.coords[k]) - else: - for v in variables.values(): - v.attrs = {} - for v in coords.values(): - v.attrs = {} + v.attrs = merge_attrs( + [v.attrs, self.coords[k].attrs], "drop_conflicts" + ) + # When keep_attrs=False, leave attrs as the function returned them attrs = self.attrs if keep_attrs else None return type(self)(variables, coords=coords, attrs=attrs) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e079332780c..a64ceefb207 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -397,8 +397,10 @@ def map( # type: ignore[override] DataArray. keep_attrs : bool | None, optional If True, both the dataset's and variables' attributes (`attrs`) will be - copied from the original objects to the new ones. If False, the new dataset - and variables will be returned without copying the attributes. + combined from the original objects and the function results using the + ``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs + are dropped. If False, the new dataset and variables will have only + the attributes set by the function. args : iterable, optional Positional arguments passed on to `func`. **kwargs : Any @@ -438,8 +440,13 @@ def map( # type: ignore[override] for k, v in self.data_vars.items() } if keep_attrs: + # Merge attrs from function result and original, dropping conflicts + from xarray.structure.merge import merge_attrs + for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + v.attrs = merge_attrs( + [v.attrs, self.data_vars[k].attrs], "drop_conflicts" + ) attrs = self.attrs if keep_attrs else None # return type(self)(variables, attrs=attrs) return Dataset(variables, attrs=attrs) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index a11bf55b28f..c6d114b3f14 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6509,6 +6509,35 @@ def mixed_func(x): expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])}) assert_identical(result, expected) + def test_map_preserves_function_attrs(self) -> None: + # Regression test for GH11019 + # Attrs added by function should be preserved in result + ds = xr.Dataset({"test": ("x", [1, 2, 3], {"original": "value"})}) + + def add_attr(da): + return da.assign_attrs(new_attr="foobar") + + # With keep_attrs=True: merge using drop_conflicts (no conflict here) + result = ds.map(add_attr, keep_attrs=True) + assert result["test"].attrs == {"original": "value", "new_attr": "foobar"} + + # With keep_attrs=False: function's attrs preserved + result = ds.map(add_attr, keep_attrs=False) + assert result["test"].attrs == {"original": "value", "new_attr": "foobar"} + + # When function modifies existing attr with keep_attrs=True, conflict is dropped + def modify_attr(da): + return da.assign_attrs(original="modified", extra="added") + + result = ds.map(modify_attr, keep_attrs=True) + assert result["test"].attrs == { + "extra": "added" + } # "original" dropped due to conflict + + # When function modifies existing attr with keep_attrs=False, function wins + result = ds.map(modify_attr, keep_attrs=False) + assert result["test"].attrs == {"original": "modified", "extra": "added"} + def test_apply_pending_deprecated_map(self) -> None: data = create_test_data() data.attrs["foo"] = "bar" diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 5e913c00629..3d860b5b17a 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -786,6 +786,24 @@ def test_weighted_mean_keep_attrs_ds(): assert data.coords["dim_1"].attrs == result.coords["dim_1"].attrs +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_operations_drop_coord_attrs(as_dataset): + # Test that coord attrs are cleared when keep_attrs=False + weights = DataArray(np.random.randn(2)) + ds = Dataset( + {"a": (["dim_0", "dim_1"], np.random.randn(2, 2), {"attr": "data"})}, + coords={"dim_1": ("dim_1", ["a", "b"], {"coord_attr": "value"})}, + ) + + data: DataArray | Dataset = ds if as_dataset else ds["a"] + + result = data.weighted(weights).mean(dim="dim_0", keep_attrs=False) + + # All attrs should be cleared + assert result.attrs == {} + assert result.coords["dim_1"].attrs == {} + + @pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile")) @pytest.mark.parametrize("as_dataset", (True, False)) def test_weighted_bad_dim(operation, as_dataset):