Skip to content

Commit 90a2ab7

Browse files
committed
working
1 parent db6ec10 commit 90a2ab7

File tree

2 files changed

+44
-35
lines changed

2 files changed

+44
-35
lines changed

docs/examples/attrs_xarray_demo.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _to_path(value: Any) -> Optional[Path]:
2020
return Path(value) if value else None
2121

2222

23-
def _parse_shape(shape: str) -> tuple[str]:
23+
def _parse_dim_names(shape: str) -> tuple[str]:
2424
return tuple(
2525
[
2626
dim.strip()
@@ -33,29 +33,36 @@ def _parse_shape(shape: str) -> tuple[str]:
3333
)
3434

3535

36-
def _try_resolve_dim(data: DataTree, name: str) -> int | str:
36+
def _try_resolve_dim(data: Optional[DataTree], name: str) -> int | str:
3737
name = name.strip()
38+
if data is None:
39+
return name
3840
value = data.get(name, None)
3941
if value is not None:
4042
return value.item()
4143
root = data.root
4244
paths = [
43-
"/tdis",
44-
"/gwf/dis",
45+
"tdis",
46+
"dis",
47+
"gwf/dis",
4548
]
4649
for path in paths:
47-
value = root.get(f"{path}/{name}", None)
48-
if value is not None:
49-
return value.item()
50-
print(f"Failed to resolve dim '{name}' for '{data.name}'")
50+
try:
51+
key = f"{path}/{name}"
52+
return root[key].item()
53+
except:
54+
try:
55+
return root[path].dims[name]
56+
except:
57+
pass
5158
return name
5259

5360

5461
def _try_resolve_shape(data: DataTree, attr: Attribute) -> tuple[int | str]:
5562
shape = attr.metadata.get("shape", None)
5663
if shape is None:
5764
raise ValueError(f"Array {attr.name} missing shape metadata")
58-
shape = [_try_resolve_dim(data, dim) for dim in _parse_shape(shape)]
65+
shape = [_try_resolve_dim(data, dim) for dim in _parse_dim_names(shape)]
5966
return shape
6067

6168

@@ -85,19 +92,18 @@ def _resolve_array(
8592
return _reshape_array(value, shape)
8693

8794

88-
def _bind_tree(data: DataTree):
89-
if data.is_root:
90-
return
91-
data.parent = data.parent.assign({data.name: data})
92-
if not data.parent.is_root:
93-
_bind_tree(data.parent)
95+
def _bind_tree(self, parent):
96+
parent.data = parent.data.assign({self.data.name: self.data})
97+
self.data = parent.data[self.data.name]
98+
grandparent = getattr(parent, "parent", None)
99+
if grandparent is not None:
100+
_bind_tree(parent, grandparent)
94101

95102

96-
def _init_tree(self, **kwargs):
103+
def _init_tree(self, parent=None, **kwargs):
97104
cls = type(self)
98105
cls_name = cls.__name__.lower()
99106
spec = fields_dict(cls)
100-
parent = kwargs.get("parent", None)
101107
data = Dataset()
102108
dims = set()
103109

@@ -106,13 +112,15 @@ def _init_tree(self, **kwargs):
106112
value = kwargs.get(name, attr.default)
107113
shape = attr.metadata.get("shape", None)
108114
if shape is not None:
109-
dim_names = [
110-
_try_resolve_dim(parent, dim) for dim in _parse_shape(shape)
115+
dim_names = _parse_dim_names(shape)
116+
shape = [
117+
_try_resolve_dim(parent.data.root if parent else None, dim)
118+
for dim in dim_names
111119
]
112120
shape = tuple(
113121
[
114122
(dim if isinstance(dim, int) else kwargs.get(dim, dim))
115-
for dim in dim_names
123+
for dim in shape
116124
]
117125
)
118126
unresolved = [dim for dim in shape if not isinstance(dim, int)]
@@ -139,6 +147,9 @@ def _init_tree(self, **kwargs):
139147
data[name] = value
140148

141149
self.data = DataTree(data, name=cls_name)
150+
if parent is not None:
151+
self.parent = parent
152+
_bind_tree(self, parent)
142153

143154

144155
def _setattr(self, attr: Attribute, value: Any):
@@ -150,7 +161,7 @@ def _setattr(self, attr: Attribute, value: Any):
150161
return
151162
self.data[attr.name] = (
152163
(
153-
_parse_shape(attr.metadata["shape"]),
164+
_parse_dim_names(attr.metadata["shape"]),
154165
_resolve_array(self, attr, value),
155166
)
156167
if get_origin(attr.type) in [list, np.ndarray]
@@ -161,18 +172,17 @@ def _setattr(self, attr: Attribute, value: Any):
161172

162173
def component(cls):
163174
spec = fields_dict(cls)
164-
init = cls.__init__
165-
166-
def _init(self, *args, **kwargs):
167-
init(self, *args, **kwargs)
168-
_bind_tree(self.data)
169175

170176
def _get(self, name):
171177
if name in spec:
172-
return self.data[name]
178+
value = self.data.get(name, None)
179+
if value is not None:
180+
return value
181+
value = self.data.dims.get(name, None)
182+
if value is not None:
183+
return value
173184
return super(cls, self).__getattribute__(name)
174185

175-
cls.__init__ = _init
176186
cls.__getattribute__ = _get
177187
return cls
178188

@@ -264,7 +274,7 @@ def __init__(
264274
):
265275
_init_tree(
266276
self,
267-
parent=model.data,
277+
parent=model,
268278
length_units=length_units,
269279
nogrb=nogrb,
270280
xorigin=xorigin,
@@ -308,7 +318,7 @@ def __init__(
308318
):
309319
_init_tree(
310320
self,
311-
parent=model.data,
321+
parent=model,
312322
strt=strt,
313323
export_array_ascii=export_array_ascii,
314324
export_array_netcdf=export_array_netcdf,
@@ -369,7 +379,7 @@ def __init__(
369379
):
370380
_init_tree(
371381
self,
372-
parent=model.data,
382+
parent=model,
373383
budget_file=budget_file,
374384
budget_csv_file=budget_csv_file,
375385
head_file=head_file,
@@ -437,7 +447,7 @@ def __init__(
437447
):
438448
_init_tree(
439449
self,
440-
parent=model.data,
450+
parent=model,
441451
icelltype=icelltype,
442452
k=k,
443453
k22=k22,
@@ -456,7 +466,7 @@ def __init__(
456466
self,
457467
sim=None,
458468
):
459-
_init_tree(self, parent=sim.data)
469+
_init_tree(self, parent=sim)
460470

461471

462472
@component
@@ -490,7 +500,7 @@ def __init__(
490500
):
491501
_init_tree(
492502
self,
493-
parent=sim.data,
503+
parent=sim,
494504
nper=nper,
495505
perioddata=perioddata,
496506
time_units=time_units,

uv.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)