Skip to content

Commit a3bd20d

Browse files
authored
Support DataArray objects and nested dicts in DataTree.from_dict (#10658)
* Support DataArray objects in DataTree.from_dict Fixes #9539, #9486 * Add docs * Add support for flattening in from_dict * Fix doctest * Add note about DataTree.from_dict vs DataArray.from_dict * Fix whats new * Require nested=True for processing nested items * Better error message * improve typing
1 parent 6470460 commit a3bd20d

File tree

7 files changed

+343
-39
lines changed

7 files changed

+343
-39
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ v2025.09.1 (unreleased)
1313
New Features
1414
~~~~~~~~~~~~
1515

16+
- :py:func:`DataTree.from_dict` now supports passing in ``DataArray`` and nested
17+
dictionary values, and has a ``coords`` argument for specifying coordinates as
18+
``DataArray`` objects (:pull:`10658`).
1619
- ``engine='netcdf4'`` now supports reading and writing in-memory netCDF files.
1720
All of Xarray's netCDF backends now support in-memory reads and writes
1821
(:pull:`10624`).

xarray/core/datatree.py

Lines changed: 212 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import io
55
import itertools
66
import textwrap
7-
from collections import ChainMap
7+
from collections import ChainMap, defaultdict
88
from collections.abc import (
99
Callable,
1010
Hashable,
1111
Iterable,
1212
Iterator,
1313
Mapping,
1414
)
15+
from dataclasses import dataclass, field
1516
from html import escape
1617
from os import PathLike
1718
from typing import (
@@ -21,6 +22,7 @@
2122
Literal,
2223
NoReturn,
2324
ParamSpec,
25+
TypeAlias,
2426
TypeVar,
2527
Union,
2628
overload,
@@ -85,6 +87,7 @@
8587
DtCompatible,
8688
ErrorOptions,
8789
ErrorOptionsWithWarn,
90+
NestedDict,
8891
NetcdfWriteModes,
8992
T_ChunkDimFreq,
9093
T_ChunksFreq,
@@ -441,6 +444,20 @@ def map( # type: ignore[override]
441444
return Dataset(variables, attrs=attrs)
442445

443446

447+
FromDictDataValue: TypeAlias = "CoercibleValue | Dataset | DataTree | None"
448+
449+
450+
@dataclass
451+
class _CoordWrapper:
452+
value: CoercibleValue
453+
454+
455+
@dataclass
456+
class _DatasetArgs:
457+
data_vars: dict[str, CoercibleValue] = field(default_factory=dict)
458+
coords: dict[str, CoercibleValue] = field(default_factory=dict)
459+
460+
444461
class DataTree(
445462
NamedNode,
446463
DataTreeAggregations,
@@ -1154,51 +1171,215 @@ def drop_nodes(
11541171
result._replace_node(children=children_to_keep)
11551172
return result
11561173

1174+
@overload
1175+
@classmethod
1176+
def from_dict(
1177+
cls,
1178+
data: Mapping[str, FromDictDataValue] | None = ...,
1179+
coords: Mapping[str, CoercibleValue] | None = ...,
1180+
*,
1181+
name: str | None = ...,
1182+
nested: Literal[False] = ...,
1183+
) -> Self: ...
1184+
1185+
@overload
1186+
@classmethod
1187+
def from_dict(
1188+
cls,
1189+
data: (
1190+
Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None
1191+
) = ...,
1192+
coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = ...,
1193+
*,
1194+
name: str | None = ...,
1195+
nested: Literal[True] = ...,
1196+
) -> Self: ...
1197+
11571198
@classmethod
11581199
def from_dict(
11591200
cls,
1160-
d: Mapping[str, Dataset | DataTree | None],
1161-
/,
1201+
data: (
1202+
Mapping[str, FromDictDataValue | NestedDict[FromDictDataValue]] | None
1203+
) = None,
1204+
coords: Mapping[str, CoercibleValue | NestedDict[CoercibleValue]] | None = None,
1205+
*,
11621206
name: str | None = None,
1207+
nested: bool = False,
11631208
) -> Self:
11641209
"""
11651210
Create a datatree from a dictionary of data objects, organised by paths into the tree.
11661211
11671212
Parameters
11681213
----------
1169-
d : dict-like
1170-
A mapping from path names to xarray.Dataset or DataTree objects.
1214+
data : dict-like, optional
1215+
A mapping from path names to ``None`` (indicating an empty node),
1216+
``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or
1217+
a nested dictionary of any of the above types.
11711218
1172-
Path names are to be given as unix-like path. If path names
1173-
containing more than one part are given, new tree nodes will be
1174-
constructed as necessary.
1219+
Path names should be given as unix-like paths, either absolute
1220+
(/path/to/item) or relative to the root node (path/to/item). If path
1221+
names containing more than one part are given, new tree nodes will
1222+
be constructed automatically as necessary.
11751223
11761224
To assign data to the root node of the tree use "", ".", "/" or "./"
11771225
as the path.
1226+
coords : dict-like, optional
1227+
A mapping from path names to objects coercible into a DataArray, or
1228+
nested dictionaries of coercible objects.
11781229
name : Hashable | None, optional
11791230
Name for the root node of the tree. Default is None.
1231+
nested : bool, optional
1232+
If true, nested dictionaries in ``data`` and ``coords`` are
1233+
automatically flattened.
11801234
11811235
Returns
11821236
-------
11831237
DataTree
11841238
1239+
See also
1240+
--------
1241+
Dataset
1242+
11851243
Notes
11861244
-----
1187-
If your dictionary is nested you will need to flatten it before using this method.
1188-
"""
1189-
# Find any values corresponding to the root
1190-
d_cast = dict(d)
1191-
root_data = None
1192-
for key in ("", ".", "/", "./"):
1193-
if key in d_cast:
1194-
if root_data is not None:
1245+
``DataTree.from_dict`` serves a conceptually different purpose from
1246+
``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a
1247+
hierarchy of Xarray objects into a DataTree, rather than converting pure
1248+
Python data structures.
1249+
1250+
Examples
1251+
--------
1252+
1253+
Construct a tree from a dict of Dataset objects:
1254+
1255+
>>> dt = DataTree.from_dict(
1256+
... {
1257+
... "/": Dataset(coords={"time": [1, 2, 3]}),
1258+
... "/ocean": Dataset(
1259+
... {
1260+
... "temperature": ("time", [4, 5, 6]),
1261+
... "salinity": ("time", [7, 8, 9]),
1262+
... }
1263+
... ),
1264+
... "/atmosphere": Dataset(
1265+
... {
1266+
... "temperature": ("time", [2, 3, 4]),
1267+
... "humidity": ("time", [3, 4, 5]),
1268+
... }
1269+
... ),
1270+
... }
1271+
... )
1272+
>>> dt
1273+
<xarray.DataTree>
1274+
Group: /
1275+
│ Dimensions: (time: 3)
1276+
│ Coordinates:
1277+
│ * time (time) int64 24B 1 2 3
1278+
├── Group: /ocean
1279+
│ Dimensions: (time: 3)
1280+
│ Data variables:
1281+
│ temperature (time) int64 24B 4 5 6
1282+
│ salinity (time) int64 24B 7 8 9
1283+
└── Group: /atmosphere
1284+
Dimensions: (time: 3)
1285+
Data variables:
1286+
temperature (time) int64 24B 2 3 4
1287+
humidity (time) int64 24B 3 4 5
1288+
1289+
Or equivalently, use a dict of values that can be converted into
1290+
`DataArray` objects, with syntax similar to the Dataset constructor:
1291+
1292+
>>> dt2 = DataTree.from_dict(
1293+
... data={
1294+
... "/ocean/temperature": ("time", [4, 5, 6]),
1295+
... "/ocean/salinity": ("time", [7, 8, 9]),
1296+
... "/atmosphere/temperature": ("time", [2, 3, 4]),
1297+
... "/atmosphere/humidity": ("time", [3, 4, 5]),
1298+
... },
1299+
... coords={"/time": [1, 2, 3]},
1300+
... )
1301+
>>> assert dt.identical(dt2)
1302+
1303+
Nested dictionaries are automatically flattened if ``nested=True``:
1304+
1305+
>>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True)
1306+
<xarray.DataTree>
1307+
Group: /
1308+
└── Group: /a
1309+
└── Group: /a/b
1310+
└── Group: /a/b/c
1311+
Dimensions: ()
1312+
Data variables:
1313+
x int64 8B 1
1314+
y int64 8B 2
1315+
1316+
"""
1317+
if data is None:
1318+
data = {}
1319+
1320+
if coords is None:
1321+
coords = {}
1322+
1323+
if nested:
1324+
data_items = utils.flat_items(data)
1325+
coords_items = utils.flat_items(coords)
1326+
else:
1327+
data_items = data.items()
1328+
coords_items = coords.items()
1329+
for arg_name, items in [("data", data_items), ("coords", coords_items)]:
1330+
for key, value in items:
1331+
if isinstance(value, dict):
1332+
raise TypeError(
1333+
f"{arg_name} contains a dict value at {key=}, "
1334+
"which is not a valid argument to "
1335+
f"DataTree.from_dict() with nested=False: {value}"
1336+
)
1337+
1338+
# Canonicalize and unify paths between `data` and `coords`
1339+
flat_data_and_coords = itertools.chain(
1340+
data_items,
1341+
((k, _CoordWrapper(v)) for k, v in coords_items),
1342+
)
1343+
nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {}
1344+
for key, value in flat_data_and_coords:
1345+
path = NodePath(key).absolute()
1346+
if path in nodes:
1347+
raise ValueError(
1348+
f"multiple entries found corresponding to node {str(path)!r}"
1349+
)
1350+
nodes[path] = value
1351+
1352+
# Merge nodes corresponding to DataArrays into Datasets
1353+
dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs)
1354+
for path in list(nodes):
1355+
node = nodes[path]
1356+
if node is not None and not isinstance(node, Dataset | DataTree):
1357+
if path.parent == path:
1358+
raise ValueError("cannot set DataArray value at root")
1359+
if path.parent in nodes:
11951360
raise ValueError(
1196-
"multiple entries found corresponding to the root node"
1361+
f"cannot set DataArray value at {str(path)!r} when "
1362+
f"parent node at {str(path.parent)!r} is also set"
11971363
)
1198-
root_data = d_cast.pop(key)
1364+
del nodes[path]
1365+
if isinstance(node, _CoordWrapper):
1366+
dataset_args[path.parent].coords[path.name] = node.value
1367+
else:
1368+
dataset_args[path.parent].data_vars[path.name] = node
1369+
for path, args in dataset_args.items():
1370+
try:
1371+
nodes[path] = Dataset(args.data_vars, args.coords)
1372+
except (ValueError, TypeError) as e:
1373+
raise type(e)(
1374+
"failed to construct xarray.Dataset for DataTree node at "
1375+
f"{str(path)!r} with data_vars={args.data_vars} and "
1376+
f"coords={args.coords}"
1377+
) from e
11991378

12001379
# Create the root node
1201-
if isinstance(root_data, DataTree):
1380+
root_data = nodes.pop(NodePath("/"), None)
1381+
if isinstance(root_data, cls):
1382+
# use cls so type-checkers understand this method returns Self
12021383
obj = root_data.copy()
12031384
obj.name = name
12041385
elif root_data is None or isinstance(root_data, Dataset):
@@ -1209,31 +1390,29 @@ def from_dict(
12091390
f"or DataTree, got {type(root_data)}"
12101391
)
12111392

1212-
def depth(item) -> int:
1213-
pathstr, _ = item
1214-
return len(NodePath(pathstr).parts)
1393+
def depth(item: tuple[NodePath, object]) -> int:
1394+
node_path, _ = item
1395+
return len(node_path.parts)
12151396

1216-
if d_cast:
1217-
# Populate tree with children determined from data_objects mapping
1397+
if nodes:
1398+
# Populate tree with children
12181399
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
1219-
for path, data in sorted(d_cast.items(), key=depth):
1400+
for path, node in sorted(nodes.items(), key=depth):
12201401
# Create and set new node
1221-
if isinstance(data, DataTree):
1222-
new_node = data.copy()
1223-
elif isinstance(data, Dataset) or data is None:
1224-
new_node = cls(dataset=data)
1402+
if isinstance(node, DataTree):
1403+
new_node = node.copy()
1404+
elif isinstance(node, Dataset) or node is None:
1405+
new_node = cls(dataset=node)
12251406
else:
1226-
raise TypeError(f"invalid values: {data}")
1407+
raise TypeError(f"invalid values: {node}")
12271408
obj._set_item(
12281409
path,
12291410
new_node,
12301411
allow_overwrite=False,
12311412
new_nodes_along_path=True,
12321413
)
12331414

1234-
# TODO: figure out why mypy is raising an error here, likely something
1235-
# to do with the return type of Dataset.copy()
1236-
return obj # type: ignore[return-value]
1415+
return obj
12371416

12381417
def to_dict(self, relative: bool = False) -> dict[str, Dataset]:
12391418
"""

xarray/core/treenode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def __init__(self, *pathsegments):
4242
)
4343
# TODO should we also forbid suffixes to avoid node names with dots in them?
4444

45+
def absolute(self) -> Self:
46+
"""Convert into an absolute path."""
47+
return type(self)("/", *self.parts)
48+
4549

4650
class TreeNode:
4751
"""

xarray/core/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...
304304
def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...
305305

306306

307+
_T = TypeVar("_T")
308+
NestedDict = dict[str, "NestedDict[_T] | _T"]
309+
310+
307311
AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True)
308312

309313

xarray/core/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
)
9898

9999
if TYPE_CHECKING:
100-
from xarray.core.types import Dims, ErrorOptionsWithWarn
100+
from xarray.core.types import Dims, ErrorOptionsWithWarn, NestedDict
101101

102102
K = TypeVar("K")
103103
V = TypeVar("V")
@@ -335,6 +335,25 @@ def remove_incompatible_items(
335335
del first_dict[k]
336336

337337

338+
def flat_items(
339+
nested: Mapping[str, NestedDict[T] | T],
340+
prefix: str | None = None,
341+
separator: str = "/",
342+
) -> Iterable[tuple[str, T]]:
343+
"""Yields flat items from a nested dictionary of dicts.
344+
345+
Notes:
346+
- Only dict subclasses are flattened.
347+
- Duplicate items are not removed. These should be checked separately.
348+
"""
349+
for key, value in nested.items():
350+
key = prefix + separator + key if prefix is not None else key
351+
if isinstance(value, dict):
352+
yield from flat_items(value, key, separator)
353+
else:
354+
yield key, value
355+
356+
338357
def is_full_slice(value: Any) -> bool:
339358
return isinstance(value, slice) and value == slice(None)
340359

0 commit comments

Comments
 (0)