99import xattree
1010from attrs import define
1111from cattrs import Converter
12- from modflow_devtools .dfn .schema .block import block_sort_key
1312from numpy .typing import NDArray
1413from xattree import get_xatspec
1514
@@ -88,7 +87,7 @@ def _path_to_tuple(field_name: str, path_value: Path) -> tuple:
8887 return (field_name .upper (), "FILEOUT" , str (path_value ))
8988
9089
91- def _user_dims (value , field_value ):
90+ def _hack_spatial_dims (value , field_value ):
9291 # terrible hack to convert flat nodes dimension to 3d structured dims.
9392 # long term solution for this is to use a custom xarray index. filters
9493 # should then have access to all dimensions needed.
@@ -162,15 +161,18 @@ def _get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple]]]:
162161 if block_name not in blocks :
163162 blocks [block_name ] = {}
164163 blocks [block_name ][name ] = [
165- _Binding .from_component (comp ).to_tuple ()
166- for comp in child
167- if comp is not None
164+ _Binding .from_component (comp ).to_tuple () for comp in child if comp is not None
168165 ]
169166 case _:
170167 raise ValueError (f"Unexpected child type: { type (child )} " )
171168
172169 return blocks
173170
171+
172+ def _has_spatial_dims (value : xr .DataArray ) -> bool :
173+ return any (dim in value .dims for dim in ["nlay" , "nrow" , "ncol" , "nodes" ])
174+
175+
174176def unstructure_component (value : Component ) -> dict [str , Any ]:
175177 dfnspec = value .dfn
176178 xatspec = xattree .get_xatspec (type (value ))
@@ -183,13 +185,13 @@ def unstructure_component(value: Component) -> dict[str, Any]:
183185
184186 for field_name in block .keys ():
185187 # Skip child components that have been processed as bindings
186- if isinstance (value , Context ) and field_name in xatspec .children :
187- child_spec = xatspec .children [field_name ]
188- if hasattr (child_spec , "metadata" ) and "block" in child_spec .metadata : # type: ignore
189- if child_spec .metadata ["block" ] == block_name : # type: ignore
190- continue
188+ if (
189+ isinstance (value , Context )
190+ and (child_spec := xatspec .children .get (field_name , None ))
191+ and child_spec .metadata ["block" ] == block_name
192+ ):
193+ continue
191194
192- field_value = data [field_name ]
193195 # convert:
194196 # - bools to keywords
195197 # - paths to records
@@ -198,73 +200,74 @@ def unstructure_component(value: Component) -> dict[str, Any]:
198200 # - xarray DataArrays with 'nper' dimension to kper-sliced datasets
199201 # (and split the period data into separate kper-indexed blocks)
200202 # - other values to their original form
201- if isinstance (field_value , bool ):
202- if field_value : # only write if true
203- blocks [block_name ][field_name ] = field_value
204- if isinstance (field_value , Path ):
205- rec = _path_to_tuple (field_name , field_value )
206- field_name = rec [0 ] # '_file' suffix dropped
207- blocks [block_name ][field_name ] = rec
208- elif isinstance (field_value , datetime ):
209- blocks [block_name ][field_name ] = field_value .isoformat ()
210- elif isinstance (field_value , xr .DataArray ):
211- if field_name == "auxiliary" :
212- blocks [block_name ][field_name ] = tuple (field_value .values .tolist ())
213- elif "nper" not in field_value .dims :
214- blocks [block_name ][field_name ] = _user_dims (value , field_value )
215- else :
216- period_data = {}
217- period_blocks = {}
218- has_spatial_dims = any (
219- dim in field_value .dims for dim in ["nlay" , "nrow" , "ncol" , "nodes" ]
220- )
221- if has_spatial_dims :
222- field_value = _user_dims (value , field_value )
223-
224- period_data [field_name ] = {
225- kper : field_value .isel (nper = kper )
226- for kper in range (field_value .sizes ["nper" ])
227- }
203+ match field_value := data [field_name ]:
204+ case None :
205+ pass
206+ case bool ():
207+ if field_value : # only write if true
208+ blocks [block_name ][field_name ] = field_value
209+ case Path ():
210+ rec = _path_to_tuple (field_name , field_value )
211+ field_name = rec [0 ] # '_file' suffix dropped
212+ blocks [block_name ][field_name ] = rec
213+ case datetime ():
214+ blocks [block_name ][field_name ] = field_value .isoformat ()
215+ case xr .DataArray ():
216+ if field_name == "auxiliary" :
217+ blocks [block_name ][field_name ] = tuple (field_value .values .tolist ())
218+ elif "nper" not in field_value .dims :
219+ blocks [block_name ][field_name ] = _hack_spatial_dims (value , field_value )
228220 else :
229- if np .issubdtype (field_value .dtype , np .str_ ):
221+ period_data = {}
222+ period_blocks = {}
223+ if _has_spatial_dims (field_value ):
224+ field_value = _hack_spatial_dims (value , field_value )
230225 period_data [field_name ] = {
231- kper : field_value [ kper ]
226+ kper : field_value . isel ( nper = kper )
232227 for kper in range (field_value .sizes ["nper" ])
233- if field_value [kper ] is not None
234228 }
235229 else :
236- if block_name not in period_data :
237- period_data [block_name ] = {}
238- period_data [block_name ][field_name ] = field_value # type: ignore
239-
240- dataset = xr .Dataset (period_data [block_name ])
241- _attach_field_metadata (dataset , type (value ), list (period_data [block_name ].keys ())) # type: ignore
242- blocks [block_name ] = {block_name : dataset }
243- del period_data [block_name ]
244-
245- for arr_name , periods in period_data .items ():
246- for kper , arr in periods .items ():
247- if isinstance (arr , xr .DataArray ):
248- max = arr .max ()
249- if max == arr .min () and max == FILL_DNODATA :
250- # don't write empty period blocks unless
251- # to intentionally reset data
252- pass
230+ if np .issubdtype (field_value .dtype , np .str_ ):
231+ period_data [field_name ] = {
232+ kper : field_value [kper ]
233+ for kper in range (field_value .sizes ["nper" ])
234+ if field_value [kper ] is not None
235+ }
236+ else :
237+ if block_name not in period_data :
238+ period_data [block_name ] = {}
239+ period_data [block_name ][field_name ] = field_value # type: ignore
240+
241+ dataset = xr .Dataset (period_data [block_name ])
242+ _attach_field_metadata (
243+ dataset , type (value ), list (period_data [block_name ].keys ())
244+ ) # type: ignore
245+ blocks [block_name ] = {block_name : dataset }
246+ del period_data [block_name ]
247+
248+ for arr_name , periods in period_data .items ():
249+ for kper , arr in periods .items ():
250+ if isinstance (arr , xr .DataArray ):
251+ max = arr .max ()
252+ if max == arr .min () and max == FILL_DNODATA :
253+ # don't write empty period blocks unless
254+ # to intentionally reset data
255+ pass
256+ else :
257+ if kper not in period_blocks :
258+ period_blocks [kper ] = {}
259+ period_blocks [kper ][arr_name ] = arr
253260 else :
254261 if kper not in period_blocks :
255262 period_blocks [kper ] = {}
256- period_blocks [kper ][arr_name ] = arr
257- else :
258- if kper not in period_blocks :
259- period_blocks [kper ] = {}
260- period_blocks [kper ][arr_name ] = arr .upper ()
261-
262- for kper , block in period_blocks .items ():
263- dataset = xr .Dataset (block )
264- _attach_field_metadata (dataset , type (value ), list (block .keys ()))
265- blocks [f"{ block_name } { kper + 1 } " ] = {block_name : dataset }
266- elif field_value is not None :
267- blocks [block_name ][field_name ] = field_value
263+ period_blocks [kper ][arr_name ] = arr .upper ()
264+
265+ for kper , block in period_blocks .items ():
266+ dataset = xr .Dataset (block )
267+ _attach_field_metadata (dataset , type (value ), list (block .keys ()))
268+ blocks [f"{ block_name } { kper + 1 } " ] = {block_name : dataset }
269+ case _:
270+ blocks [block_name ][field_name ] = field_value
268271
269272 # make sure options block always comes first
270273 # TODO: blocks should already be sorted here
0 commit comments