44
55import numpy as np
66import xarray as xr
7+ from modflow_devtools .dfn .schema .v2 import FieldType
78from numpy .typing import NDArray
89
910from flopy4 .mf6 .constants import FILL_DNODATA
1011
1112
def field_type(value: Any) -> FieldType:
    """
    Get a value's type according to the MF6 specification.

    Parameters
    ----------
    value : Any
        A Python scalar, NumPy scalar, `str`, `tuple`, `list`, `dict`,
        `xarray.DataArray` or `xarray.Dataset`.

    Returns
    -------
    FieldType
        One of "keyword", "integer", "double", "string", "record",
        "list" or "array".
        NOTE(review): plain strings are returned; presumably `FieldType`
        is a string-literal alias -- confirm against modflow_devtools.

    Raises
    ------
    ValueError
        If the value's type is not one of the supported types.
    """
    # bool must be tested before int because bool is a subclass of int.
    # NumPy scalar types are listed explicitly: np.float64 subclasses float,
    # but np.bool_, np.integer and np.float32 do not subclass the builtins.
    if isinstance(value, (bool, np.bool_)):
        return "keyword"
    if isinstance(value, (int, np.integer)):
        return "integer"
    if isinstance(value, (float, np.floating)):
        return "double"
    if isinstance(value, str):
        return "string"
    if isinstance(value, tuple):
        return "record"
    if isinstance(value, xr.DataArray):
        # object-dtype arrays hold heterogeneous records rather than a grid
        return "list" if value.dtype == "object" else "array"
    if isinstance(value, (list, dict, xr.Dataset)):
        return "list"
    raise ValueError(f"Unsupported field type: {type(value)}")
6033
6134
6235def array_how (value : xr .DataArray ) -> str :
@@ -140,38 +113,42 @@ def array2string(value: NDArray) -> str:
140113 return buffer .getvalue ().strip ()
141114
142115
def nonempty(value: NDArray | xr.DataArray) -> NDArray:
    """
    Return a boolean mask of non-empty (non-nodata) values in an array.
    TODO: don't hardcode FILL_DNODATA, support different fill values

    Parameters
    ----------
    value : NDArray | xr.DataArray
        The array to inspect; a `DataArray` is unwrapped to its `.values`.

    Returns
    -------
    NDArray
        Boolean mask, True where an entry holds data. For object arrays
        that means "not None"; otherwise it means the entry is not masked
        as invalid (NaN/inf) and is not equal to `FILL_DNODATA`.
    """
    data = value.values if isinstance(value, xr.DataArray) else value

    if data.dtype == "object":
        # elementwise comparison is intended here; `is not None` would test
        # the array object itself rather than its entries
        return data != None  # noqa: E711

    # NOTE(review): masked_invalid is meant for numeric data -- confirm that
    # callers never pass fixed-width string arrays down this path
    valid = ~np.ma.masked_invalid(data).mask
    return valid & (data != FILL_DNODATA)
152129
153130
def data2list(value: list | tuple | dict | xr.Dataset | xr.DataArray):
    """
    Yield records (tuples) from data in a `list`, `tuple`, `dict`,
    `DataArray` or `Dataset`.

    - `list` / `tuple`: each element is yielded unchanged.
    - `dict`: each entry is yielded as a `(name, value)` record.
    - `Dataset`: delegated to `dataset2list`.
    - scalar (0-d) `DataArray`: a single `(value,)` record, or nothing
      if the value is None or NaN.
    """

    if isinstance(value, (list, tuple)):
        yield from value
        return

    if isinstance(value, dict):
        # iterate items(), not values(): each record pairs a key with its value
        yield from value.items()
        return

    if isinstance(value, xr.Dataset):
        yield from dataset2list(value)
        return

    # otherwise we have a DataArray
    if value.ndim == 0:  # handle scalar
        item = value.item()
        # check None first and only ask np.isnan about numeric values;
        # np.isnan raises TypeError for None and for strings
        is_nan = isinstance(item, (float, complex)) and np.isnan(item)
        if item is not None and not is_nan:
            yield (item,)
        return
@@ -184,90 +161,67 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
184161 for i , val in enumerate (values ):
185162 if has_spatial_dims :
186163 cellid = tuple (idx [i ] + 1 for idx in indices )
187- result = cellid + (val ,)
164+ rec = cellid + (val ,)
188165 else :
189- result = (val ,)
190- yield result
166+ rec = (val ,)
167+ yield rec
191168
192169
def _unwrap_scalar(val):
    """Return a plain Python value for a zero-dimensional array, else `val` itself."""
    return val.item() if val.shape == () else val


def _dataset_records(value, is_union, pos=None, cellid=()):
    """
    Yield the record(s) of a dataset at one position.

    With `pos=None` each data variable is taken whole (scalar datasets);
    otherwise each is indexed at `pos`. A union (keystring) dataset yields
    one `(*name_parts, value)` record per variable; otherwise a single
    record `cellid + values` is yielded.
    """
    names = list(value.data_vars.keys())
    # generator keeps extraction lazy, interleaved with the yields below
    vals = (
        _unwrap_scalar(value[name] if pos is None else value[name][pos])
        for name in names
    )
    if is_union:
        for name, val in zip(names, vals):
            yield (*name.split("_"), val)
    else:
        yield cellid + tuple(vals)


def dataset2list(value: xr.Dataset):
    """
    Yield records (tuples) from an `xarray.Dataset`.

    If the first data variable is a string type, assume all are
    string type. Then the dataset represents a keystring; yield
    tuples of (*name.split("_"), value), one per data variable.
    Otherwise, yield tuples: (*value) if no spatial dimensions, or
    (*cellid, *value) when spatial dimensions are present, where
    cellid holds 1-based indices.

    Nothing is yielded for a `None` or empty dataset, nor when no
    position is non-empty (per `nonempty`) in any data variable.
    """
    # count the variables rather than testing the truthiness of their names
    if value is None or len(value.data_vars) == 0:
        return

    first = next(iter(value.data_vars.values()))
    is_union = first.dtype.type is np.str_

    if first.ndim == 0:  # handle scalar
        yield from _dataset_records(value, is_union)
        return

    # a position is kept if any data variable has data there
    combined_mask: Any = None
    for arr in value.data_vars.values():  # do not rebind `first` here
        mask = nonempty(arr)
        combined_mask = mask if combined_mask is None else combined_mask | mask
    if combined_mask is None or not np.any(combined_mask):
        return

    has_spatial_dims = any(
        d in ("nlay", "nrow", "ncol", "nodes") for d in first.dims
    )
    for pos in zip(*np.where(combined_mask)):
        # cell ids are 1-based; positions from np.where are 0-based
        cellid = tuple(p + 1 for p in pos) if has_spatial_dims else ()
        yield from _dataset_records(value, is_union, pos, cellid)
0 commit comments