4
4
import io
5
5
import itertools
6
6
import textwrap
7
- from collections import ChainMap
7
+ from collections import ChainMap , defaultdict
8
8
from collections .abc import (
9
9
Callable ,
10
10
Hashable ,
11
11
Iterable ,
12
12
Iterator ,
13
13
Mapping ,
14
14
)
15
+ from dataclasses import dataclass , field
15
16
from html import escape
16
17
from os import PathLike
17
18
from typing import (
21
22
Literal ,
22
23
NoReturn ,
23
24
ParamSpec ,
25
+ TypeAlias ,
24
26
TypeVar ,
25
27
Union ,
26
28
overload ,
85
87
DtCompatible ,
86
88
ErrorOptions ,
87
89
ErrorOptionsWithWarn ,
90
+ NestedDict ,
88
91
NetcdfWriteModes ,
89
92
T_ChunkDimFreq ,
90
93
T_ChunksFreq ,
@@ -441,6 +444,20 @@ def map( # type: ignore[override]
441
444
return Dataset (variables , attrs = attrs )
442
445
443
446
447
+ FromDictDataValue : TypeAlias = "CoercibleValue | Dataset | DataTree | None"
448
+
449
+
450
+ @dataclass
451
+ class _CoordWrapper :
452
+ value : CoercibleValue
453
+
454
+
455
+ @dataclass
456
+ class _DatasetArgs :
457
+ data_vars : dict [str , CoercibleValue ] = field (default_factory = dict )
458
+ coords : dict [str , CoercibleValue ] = field (default_factory = dict )
459
+
460
+
444
461
class DataTree (
445
462
NamedNode ,
446
463
DataTreeAggregations ,
@@ -1154,51 +1171,215 @@ def drop_nodes(
1154
1171
result ._replace_node (children = children_to_keep )
1155
1172
return result
1156
1173
1174
+ @overload
1175
+ @classmethod
1176
+ def from_dict (
1177
+ cls ,
1178
+ data : Mapping [str , FromDictDataValue ] | None = ...,
1179
+ coords : Mapping [str , CoercibleValue ] | None = ...,
1180
+ * ,
1181
+ name : str | None = ...,
1182
+ nested : Literal [False ] = ...,
1183
+ ) -> Self : ...
1184
+
1185
+ @overload
1186
+ @classmethod
1187
+ def from_dict (
1188
+ cls ,
1189
+ data : (
1190
+ Mapping [str , FromDictDataValue | NestedDict [FromDictDataValue ]] | None
1191
+ ) = ...,
1192
+ coords : Mapping [str , CoercibleValue | NestedDict [CoercibleValue ]] | None = ...,
1193
+ * ,
1194
+ name : str | None = ...,
1195
+ nested : Literal [True ] = ...,
1196
+ ) -> Self : ...
1197
+
1157
1198
@classmethod
1158
1199
def from_dict (
1159
1200
cls ,
1160
- d : Mapping [str , Dataset | DataTree | None ],
1161
- / ,
1201
+ data : (
1202
+ Mapping [str , FromDictDataValue | NestedDict [FromDictDataValue ]] | None
1203
+ ) = None ,
1204
+ coords : Mapping [str , CoercibleValue | NestedDict [CoercibleValue ]] | None = None ,
1205
+ * ,
1162
1206
name : str | None = None ,
1207
+ nested : bool = False ,
1163
1208
) -> Self :
1164
1209
"""
1165
1210
Create a datatree from a dictionary of data objects, organised by paths into the tree.
1166
1211
1167
1212
Parameters
1168
1213
----------
1169
- d : dict-like
1170
- A mapping from path names to xarray.Dataset or DataTree objects.
1214
+ data : dict-like, optional
1215
+ A mapping from path names to ``None`` (indicating an empty node),
1216
+ ``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or
1217
+ a nested dictionary of any of the above types.
1171
1218
1172
- Path names are to be given as unix-like path. If path names
1173
- containing more than one part are given, new tree nodes will be
1174
- constructed as necessary.
1219
+ Path names should be given as unix-like paths, either absolute
1220
+ (/path/to/item) or relative to the root node (path/to/item). If path
1221
+ names containing more than one part are given, new tree nodes will
1222
+ be constructed automatically as necessary.
1175
1223
1176
1224
To assign data to the root node of the tree use "", ".", "/" or "./"
1177
1225
as the path.
1226
+ coords : dict-like, optional
1227
+ A mapping from path names to objects coercible into a DataArray, or
1228
+ nested dictionaries of coercible objects.
1178
1229
name : Hashable | None, optional
1179
1230
Name for the root node of the tree. Default is None.
1231
+ nested : bool, optional
1232
+ If true, nested dictionaries in ``data`` and ``coords`` are
1233
+ automatically flattened.
1180
1234
1181
1235
Returns
1182
1236
-------
1183
1237
DataTree
1184
1238
1239
+ See also
1240
+ --------
1241
+ Dataset
1242
+
1185
1243
Notes
1186
1244
-----
1187
- If your dictionary is nested you will need to flatten it before using this method.
1188
- """
1189
- # Find any values corresponding to the root
1190
- d_cast = dict (d )
1191
- root_data = None
1192
- for key in ("" , "." , "/" , "./" ):
1193
- if key in d_cast :
1194
- if root_data is not None :
1245
+ ``DataTree.from_dict`` serves a conceptually different purpose from
1246
+ ``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a
1247
+ hierarchy of Xarray objects into a DataTree, rather than converting pure
1248
+ Python data structures.
1249
+
1250
+ Examples
1251
+ --------
1252
+
1253
+ Construct a tree from a dict of Dataset objects:
1254
+
1255
+ >>> dt = DataTree.from_dict(
1256
+ ... {
1257
+ ... "/": Dataset(coords={"time": [1, 2, 3]}),
1258
+ ... "/ocean": Dataset(
1259
+ ... {
1260
+ ... "temperature": ("time", [4, 5, 6]),
1261
+ ... "salinity": ("time", [7, 8, 9]),
1262
+ ... }
1263
+ ... ),
1264
+ ... "/atmosphere": Dataset(
1265
+ ... {
1266
+ ... "temperature": ("time", [2, 3, 4]),
1267
+ ... "humidity": ("time", [3, 4, 5]),
1268
+ ... }
1269
+ ... ),
1270
+ ... }
1271
+ ... )
1272
+ >>> dt
1273
+ <xarray.DataTree>
1274
+ Group: /
1275
+ │ Dimensions: (time: 3)
1276
+ │ Coordinates:
1277
+ │ * time (time) int64 24B 1 2 3
1278
+ ├── Group: /ocean
1279
+ │ Dimensions: (time: 3)
1280
+ │ Data variables:
1281
+ │ temperature (time) int64 24B 4 5 6
1282
+ │ salinity (time) int64 24B 7 8 9
1283
+ └── Group: /atmosphere
1284
+ Dimensions: (time: 3)
1285
+ Data variables:
1286
+ temperature (time) int64 24B 2 3 4
1287
+ humidity (time) int64 24B 3 4 5
1288
+
1289
+ Or equivalently, use a dict of values that can be converted into
1290
+ `DataArray` objects, with syntax similar to the Dataset constructor:
1291
+
1292
+ >>> dt2 = DataTree.from_dict(
1293
+ ... data={
1294
+ ... "/ocean/temperature": ("time", [4, 5, 6]),
1295
+ ... "/ocean/salinity": ("time", [7, 8, 9]),
1296
+ ... "/atmosphere/temperature": ("time", [2, 3, 4]),
1297
+ ... "/atmosphere/humidity": ("time", [3, 4, 5]),
1298
+ ... },
1299
+ ... coords={"/time": [1, 2, 3]},
1300
+ ... )
1301
+ >>> assert dt.identical(dt2)
1302
+
1303
+ Nested dictionaries are automatically flattened if ``nested=True``:
1304
+
1305
+ >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True)
1306
+ <xarray.DataTree>
1307
+ Group: /
1308
+ └── Group: /a
1309
+ └── Group: /a/b
1310
+ └── Group: /a/b/c
1311
+ Dimensions: ()
1312
+ Data variables:
1313
+ x int64 8B 1
1314
+ y int64 8B 2
1315
+
1316
+ """
1317
+ if data is None :
1318
+ data = {}
1319
+
1320
+ if coords is None :
1321
+ coords = {}
1322
+
1323
+ if nested :
1324
+ data_items = utils .flat_items (data )
1325
+ coords_items = utils .flat_items (coords )
1326
+ else :
1327
+ data_items = data .items ()
1328
+ coords_items = coords .items ()
1329
+ for arg_name , items in [("data" , data_items ), ("coords" , coords_items )]:
1330
+ for key , value in items :
1331
+ if isinstance (value , dict ):
1332
+ raise TypeError (
1333
+ f"{ arg_name } contains a dict value at { key = } , "
1334
+ "which is not a valid argument to "
1335
+ f"DataTree.from_dict() with nested=False: { value } "
1336
+ )
1337
+
1338
+ # Canonicalize and unify paths between `data` and `coords`
1339
+ flat_data_and_coords = itertools .chain (
1340
+ data_items ,
1341
+ ((k , _CoordWrapper (v )) for k , v in coords_items ),
1342
+ )
1343
+ nodes : dict [NodePath , _CoordWrapper | FromDictDataValue ] = {}
1344
+ for key , value in flat_data_and_coords :
1345
+ path = NodePath (key ).absolute ()
1346
+ if path in nodes :
1347
+ raise ValueError (
1348
+ f"multiple entries found corresponding to node { str (path )!r} "
1349
+ )
1350
+ nodes [path ] = value
1351
+
1352
+ # Merge nodes corresponding to DataArrays into Datasets
1353
+ dataset_args : defaultdict [NodePath , _DatasetArgs ] = defaultdict (_DatasetArgs )
1354
+ for path in list (nodes ):
1355
+ node = nodes [path ]
1356
+ if node is not None and not isinstance (node , Dataset | DataTree ):
1357
+ if path .parent == path :
1358
+ raise ValueError ("cannot set DataArray value at root" )
1359
+ if path .parent in nodes :
1195
1360
raise ValueError (
1196
- "multiple entries found corresponding to the root node"
1361
+ f"cannot set DataArray value at { str (path )!r} when "
1362
+ f"parent node at { str (path .parent )!r} is also set"
1197
1363
)
1198
- root_data = d_cast .pop (key )
1364
+ del nodes [path ]
1365
+ if isinstance (node , _CoordWrapper ):
1366
+ dataset_args [path .parent ].coords [path .name ] = node .value
1367
+ else :
1368
+ dataset_args [path .parent ].data_vars [path .name ] = node
1369
+ for path , args in dataset_args .items ():
1370
+ try :
1371
+ nodes [path ] = Dataset (args .data_vars , args .coords )
1372
+ except (ValueError , TypeError ) as e :
1373
+ raise type (e )(
1374
+ "failed to construct xarray.Dataset for DataTree node at "
1375
+ f"{ str (path )!r} with data_vars={ args .data_vars } and "
1376
+ f"coords={ args .coords } "
1377
+ ) from e
1199
1378
1200
1379
# Create the root node
1201
- if isinstance (root_data , DataTree ):
1380
+ root_data = nodes .pop (NodePath ("/" ), None )
1381
+ if isinstance (root_data , cls ):
1382
+ # use cls so type-checkers understand this method returns Self
1202
1383
obj = root_data .copy ()
1203
1384
obj .name = name
1204
1385
elif root_data is None or isinstance (root_data , Dataset ):
@@ -1209,31 +1390,29 @@ def from_dict(
1209
1390
f"or DataTree, got { type (root_data )} "
1210
1391
)
1211
1392
1212
- def depth (item ) -> int :
1213
- pathstr , _ = item
1214
- return len (NodePath ( pathstr ) .parts )
1393
+ def depth (item : tuple [ NodePath , object ] ) -> int :
1394
+ node_path , _ = item
1395
+ return len (node_path .parts )
1215
1396
1216
- if d_cast :
1217
- # Populate tree with children determined from data_objects mapping
1397
+ if nodes :
1398
+ # Populate tree with children
1218
1399
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
1219
- for path , data in sorted (d_cast .items (), key = depth ):
1400
+ for path , node in sorted (nodes .items (), key = depth ):
1220
1401
# Create and set new node
1221
- if isinstance (data , DataTree ):
1222
- new_node = data .copy ()
1223
- elif isinstance (data , Dataset ) or data is None :
1224
- new_node = cls (dataset = data )
1402
+ if isinstance (node , DataTree ):
1403
+ new_node = node .copy ()
1404
+ elif isinstance (node , Dataset ) or node is None :
1405
+ new_node = cls (dataset = node )
1225
1406
else :
1226
- raise TypeError (f"invalid values: { data } " )
1407
+ raise TypeError (f"invalid values: { node } " )
1227
1408
obj ._set_item (
1228
1409
path ,
1229
1410
new_node ,
1230
1411
allow_overwrite = False ,
1231
1412
new_nodes_along_path = True ,
1232
1413
)
1233
1414
1234
- # TODO: figure out why mypy is raising an error here, likely something
1235
- # to do with the return type of Dataset.copy()
1236
- return obj # type: ignore[return-value]
1415
+ return obj
1237
1416
1238
1417
def to_dict (self , relative : bool = False ) -> dict [str , Dataset ]:
1239
1418
"""
0 commit comments