4
4
import sys
5
5
from itertools import repeat
6
6
from textwrap import dedent
7
- from typing import TYPE_CHECKING , Callable , Tuple
7
+ from typing import TYPE_CHECKING , Callable
8
8
9
9
from xarray import DataArray , Dataset
10
-
11
10
from xarray .core .iterators import LevelOrderIter
12
11
from xarray .core .treenode import NodePath , TreeNode
13
12
@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
84
83
for node_a , node_b in zip (LevelOrderIter (a ), LevelOrderIter (b )):
85
84
path_a , path_b = node_a .path , node_b .path
86
85
87
- if require_names_equal :
88
- if node_a .name != node_b .name :
89
- diff = dedent (
90
- f"""\
86
+ if require_names_equal and node_a .name != node_b .name :
87
+ diff = dedent (
88
+ f"""\
91
89
Node '{ path_a } ' in the left object has name '{ node_a .name } '
92
90
Node '{ path_b } ' in the right object has name '{ node_b .name } '"""
93
- )
94
- return diff
91
+ )
92
+ return diff
95
93
96
94
if len (node_a .children ) != len (node_b .children ):
97
95
diff = dedent (
@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
125
123
func : callable
126
124
Function to apply to datasets with signature:
127
125
128
- `func(*args, **kwargs) -> Union[Dataset , Iterable[Dataset ]]`.
126
+ `func(*args, **kwargs) -> Union[DataTree , Iterable[DataTree ]]`.
129
127
130
128
(i.e. func must accept at least one Dataset and return at least one Dataset.)
131
129
Function will not be applied to any nodes without datasets.
@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
154
152
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
155
153
156
154
@functools .wraps (func )
157
- def _map_over_subtree (* args , ** kwargs ) -> DataTree | Tuple [DataTree , ...]:
155
+ def _map_over_subtree (* args , ** kwargs ) -> DataTree | tuple [DataTree , ...]:
158
156
"""Internal function which maps func over every node in tree, returning a tree of the results."""
159
157
from xarray .core .datatree import DataTree
160
158
@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
259
257
return _map_over_subtree
260
258
261
259
262
- def _handle_errors_with_path_context (path ):
260
+ def _handle_errors_with_path_context (path : str ):
263
261
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
264
262
265
263
def decorator (func ):
266
264
def wrapper (* args , ** kwargs ):
267
265
try :
268
266
return func (* args , ** kwargs )
269
267
except Exception as e :
270
- if sys .version_info >= (3 , 11 ):
271
- # Add the context information to the error message
272
- e .add_note (
273
- f"Raised whilst mapping function over node with path { path } "
274
- )
268
+ # Add the context information to the error message
269
+ add_note (
270
+ e , f"Raised whilst mapping function over node with path { path } "
271
+ )
275
272
raise
276
273
277
274
return wrapper
@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
287
284
err .add_note (msg )
288
285
289
286
290
- def _check_single_set_return_values (path_to_node , obj ):
287
+ def _check_single_set_return_values (
288
+ path_to_node : str , obj : Dataset | DataArray | tuple [Dataset | DataArray ]
289
+ ):
291
290
"""Check types returned from single evaluation of func, and return number of return values received from func."""
292
291
if isinstance (obj , (Dataset , DataArray )):
293
292
return 1
0 commit comments