Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ New Features
All of Xarray's netCDF backends now support in-memory reads and writes
(:pull:`10624`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- :py:func:`merge` now supports merging :py:class:`DataTree` objects
(:issue:`9790`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
6 changes: 1 addition & 5 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
from typing import (
TYPE_CHECKING,
Any,
TypeVar,
)
from typing import TYPE_CHECKING, Any, TypeVar

from xarray.core.types import Self
from xarray.core.utils import Frozen, is_dict_like
Expand Down
114 changes: 106 additions & 8 deletions xarray/structure/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Set as AbstractSet
from typing import TYPE_CHECKING, Any, NamedTuple, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload

import pandas as pd

Expand Down Expand Up @@ -34,6 +34,7 @@
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import (
CombineAttrsOptions,
CompatOptions,
Expand Down Expand Up @@ -793,18 +794,96 @@ def merge_core(
return _MergeResult(variables, coord_names, dims, out_indexes, attrs)


def merge_trees(
trees: Iterable[DataTree],
compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
fill_value: object = dtypes.NA,
combine_attrs: CombineAttrsOptions = "override",
) -> DataTree:
"""Merge specialized to DataTree objects."""
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

if fill_value is not dtypes.NA:
# fill_value support dicts, which probably should be mapped to sub-groups?
raise NotImplementedError(
"fill_value is not yet supported for DataTree objects in merge"
)

node_lists: defaultdict[str, list[DataTree]] = defaultdict(list)
for tree in trees:
for key, node in tree.subtree_with_keys:
node_lists[key].append(node)

root_datasets = [node.dataset for node in node_lists.pop(".")]
root_ds = merge(
root_datasets, compat=compat, join=join, combine_attrs=combine_attrs
)
result = DataTree(dataset=root_ds)

def depth(kv):
return kv[0].count("/")

for key, nodes in sorted(node_lists.items(), key=depth):
# Merge datasets, including inherited indexes to ensure alignment.
datasets = [node.dataset for node in nodes]
merge_result = merge_core(
datasets,
compat=compat,
join=join,
combine_attrs=combine_attrs,
)
# Remove inherited coordinates/indexes/dimensions.
for var_name in list(merge_result.coord_names):
if not any(var_name in node._coord_variables for node in nodes):
del merge_result.variables[var_name]
merge_result.coord_names.remove(var_name)
for index_name in list(merge_result.indexes):
if not any(index_name in node._node_indexes for node in nodes):
del merge_result.indexes[index_name]
for dim in list(merge_result.dims):
if not any(dim in node._node_dims for node in nodes):
del merge_result.dims[dim]

merged_ds = Dataset._construct_direct(**merge_result._asdict())
result[key] = DataTree(dataset=merged_ds)

return result


@overload
def merge(
objects: Iterable[DataTree],
compat: CompatOptions | CombineKwargDefault = ...,
join: JoinOptions | CombineKwargDefault = ...,
fill_value: object = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> DataTree: ...


@overload
def merge(
objects: Iterable[DataArray | Dataset | Coordinates | dict],
compat: CompatOptions | CombineKwargDefault = ...,
join: JoinOptions | CombineKwargDefault = ...,
fill_value: object = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> Dataset: ...


def merge(
objects: Iterable[DataArray | CoercibleMapping],
objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict],
compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
fill_value: object = dtypes.NA,
combine_attrs: CombineAttrsOptions = "override",
) -> Dataset:
) -> DataTree | Dataset:
"""Merge any number of xarray objects into a single Dataset as variables.

Parameters
----------
objects : iterable of Dataset or iterable of DataArray or iterable of dict-like
objects : iterable of DataArray, Dataset, DataTree or dict
Merge together all variables from these objects. If any of them are
DataArray objects, they must have a name.
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \
Expand Down Expand Up @@ -859,8 +938,9 @@ def merge(

Returns
-------
Dataset
Dataset with combined variables from each object.
Dataset or DataTree
Objects with combined variables from the inputs. If any inputs are a
DataTree, this will also be a DataTree. Otherwise it will be a Dataset.

Examples
--------
Expand Down Expand Up @@ -1023,13 +1103,31 @@ def merge(
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

objects = list(objects)

if any(isinstance(obj, DataTree) for obj in objects):
if not all(isinstance(obj, DataTree) for obj in objects):
raise TypeError(
"merge does not support mixed type arguments when one argument "
f"is a DataTree: {objects}"
)
trees = cast(list[DataTree], objects)
return merge_trees(
trees,
compat=compat,
join=join,
combine_attrs=combine_attrs,
fill_value=fill_value,
)

dict_like_objects = []
for obj in objects:
if not isinstance(obj, DataArray | Dataset | Coordinates | dict):
raise TypeError(
"objects must be an iterable containing only "
"Dataset(s), DataArray(s), and dictionaries."
"objects must be an iterable containing only DataTree(s), "
f"Dataset(s), DataArray(s), and dictionaries: {objects}"
)

if isinstance(obj, DataArray):
Expand Down
82 changes: 82 additions & 0 deletions xarray/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import warnings

import numpy as np
Expand Down Expand Up @@ -867,3 +868,84 @@ def test_merge_auto_align(self):
with set_options(use_new_combine_kwarg_defaults=True):
with pytest.raises(ValueError, match="might be related to new default"):
expected.identical(ds2.merge(ds1))


class TestMergeDataTree:
def test_mixed(self) -> None:
tree = xr.DataTree()
ds = xr.Dataset()
with pytest.raises(
TypeError,
match="merge does not support mixed type arguments when one argument is a DataTree",
):
xr.merge([tree, ds]) # type: ignore[list-item]

def test_distinct(self) -> None:
tree1 = xr.DataTree.from_dict({"/a/b/c": 1})
tree2 = xr.DataTree.from_dict({"/a/d/e": 2})
expected = xr.DataTree.from_dict({"/a/b/c": 1, "/a/d/e": 2})
merged = xr.merge([tree1, tree2])
assert_equal(merged, expected)

def test_overlap(self) -> None:
tree1 = xr.DataTree.from_dict({"/a/b": 1})
tree2 = xr.DataTree.from_dict({"/a/c": 2})
tree3 = xr.DataTree.from_dict({"/a/d": 3})
expected = xr.DataTree.from_dict({"/a/b": 1, "/a/c": 2, "/a/d": 3})
merged = xr.merge([tree1, tree2, tree3])
assert_equal(merged, expected)

def test_inherited(self) -> None:
tree1 = xr.DataTree.from_dict({"/a/b": ("x", [1])}, coords={"x": [0]})
tree2 = xr.DataTree.from_dict({"/a/c": ("x", [2])})
expected = xr.DataTree.from_dict(
{"/a/b": ("x", [1]), "a/c": ("x", [2])}, coords={"x": [0]}
)
merged = xr.merge([tree1, tree2])
assert_equal(merged, expected)

def test_inherited_join(self) -> None:
tree1 = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])}, coords={"x": [0, 1]})
tree2 = xr.DataTree.from_dict({"/a/c": ("x", [1, 2])}, coords={"x": [1, 2]})

expected = xr.DataTree.from_dict(
{"/a/b": ("x", [0, 1]), "a/c": ("x", [np.nan, 1])}, coords={"x": [0, 1]}
)
merged = xr.merge([tree1, tree2], join="left")
assert_equal(merged, expected)

expected = xr.DataTree.from_dict(
{"/a/b": ("x", [1, np.nan]), "a/c": ("x", [1, 2])}, coords={"x": [1, 2]}
)
merged = xr.merge([tree1, tree2], join="right")
assert_equal(merged, expected)

expected = xr.DataTree.from_dict(
{"/a/b": ("x", [1]), "a/c": ("x", [1])}, coords={"x": [1]}
)
merged = xr.merge([tree1, tree2], join="inner")
assert_equal(merged, expected)

expected = xr.DataTree.from_dict(
{"/a/b": ("x", [0, 1, np.nan]), "a/c": ("x", [np.nan, 1, 2])},
coords={"x": [0, 1, 2]},
)
merged = xr.merge([tree1, tree2], join="outer")
assert_equal(merged, expected)

with pytest.raises(
xr.AlignmentError,
match=re.escape("cannot align objects with join='exact'"),
):
xr.merge([tree1, tree2], join="exact")

def test_fill_value_errors(self) -> None:
trees = [xr.DataTree(), xr.DataTree()]

with pytest.raises(
NotImplementedError,
match=re.escape(
"fill_value is not yet supported for DataTree objects in merge"
),
):
xr.merge(trees, fill_value=None)
Loading