44
55import numpy as np
66import xarray as xr
7+ from modflow_devtools .dfn .schema .v2 import FieldType
78from numpy .typing import NDArray
89
910from flopy4 .mf6 .constants import FILL_DNODATA
1011
1112
12- def _is_keystring_format (dataset : xr .Dataset ) -> bool :
13- """Check if dataset should use keystring format based on metadata."""
14- field_metadata = dataset .attrs .get ("field_metadata" , {})
15- return any (meta .get ("format" ) == "keystring" for meta in field_metadata .values ())
13+ def field_type (value : Any ) -> FieldType :
14+ """Get a value's type according to the MF6 specification."""
1615
17-
18- def _is_tabular_time_format (dataset : xr .Dataset ) -> bool :
19- """True if a dataset has multiple columns and only one dimension 'nper'."""
20- return len (dataset .data_vars ) > 1 and all (
21- "nper" in var .dims and len (var .dims ) == 1 for var in dataset .data_vars .values ()
22- )
23-
24-
25- def is_dataset (value : Any ) -> bool :
26- return isinstance (value , xr .Dataset )
27-
28-
29- def field_format (value : Any ) -> str :
30- """
31- Get a field's formatting type as defined by the MF6 definition language:
32- https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
33- """
3416 if isinstance (value , bool ):
3517 return "keyword"
3618 if isinstance (value , int ):
3719 return "integer"
3820 if isinstance (value , float ):
39- return "double precision "
21+ return "double"
4022 if isinstance (value , str ):
4123 return "string"
4224 if isinstance (value , (dict , tuple )):
@@ -45,18 +27,9 @@ def field_format(value: Any) -> str:
4527 if value .dtype == "object" :
4628 return "list"
4729 return "array"
48- if isinstance (value , (xr .Dataset , list )):
49- if isinstance (value , xr .Dataset ):
50- if _is_keystring_format (value ):
51- return "keystring"
52- if _is_tabular_time_format (value ):
53- return "list"
30+ if isinstance (value , (list , xr .Dataset )):
5431 return "list"
55- return "keystring"
56-
57-
58- def has_time_dim (value : Any ) -> bool :
59- return isinstance (value , xr .DataArray ) and "nper" in value .dims
32+ raise ValueError (f"Unsupported field type: { type (value )} " )
6033
6134
6235def array_how (value : xr .DataArray ) -> str :
@@ -140,20 +113,26 @@ def array2string(value: NDArray) -> str:
140113 return buffer .getvalue ().strip ()
141114
142115
143- def nonempty (arr : NDArray | xr .DataArray ) -> NDArray :
144- if isinstance (arr , xr .DataArray ):
145- arr = arr .values
146- if arr .dtype == "object" :
147- mask = arr != None # noqa: E711
116+ def nonempty (value : NDArray | xr .DataArray ) -> NDArray :
117+ """
118+ Return a boolean mask of non-empty (non-nodata) values in an array.
119+ TODO: don't hardcode FILL_DNODATA, support different fill values
120+ """
121+ if isinstance (value , xr .DataArray ):
122+ value = value .values
123+ if value .dtype == "object" :
124+ mask = value != None # noqa: E711
148125 else :
149- mask = ~ np .ma .masked_invalid (arr ).mask
150- mask = mask & (arr != FILL_DNODATA )
126+ mask = ~ np .ma .masked_invalid (value ).mask
127+ mask = mask & (value != FILL_DNODATA )
151128 return mask
152129
153130
154- def data2list (value : list | xr .DataArray | xr .Dataset ):
131+ def data2list (value : list | dict | xr .Dataset | xr .DataArray ):
155132 """
156- Yield record tuples from a list, `DataArray` or `Dataset`.
133+ Yield records (tuples) from data in a `list`, `dict`, `DataArray` or `Dataset`.
134+ Data can be regular or irregular: every item in a `list` is of the same record
135+ type, while items in a `dict` or `Dataset` can be of different types.
157136
158137 Yields
159138 ------
@@ -162,16 +141,21 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
162141 """
163142
164143 if isinstance (value , list ):
165- for item in value :
166- yield item
144+ for rec in value :
145+ yield rec
146+ return
147+
148+ if isinstance (value , dict ):
149+ for name , val in value .items ():
150+ yield (name .upper (), val )
167151 return
168152
169153 if isinstance (value , xr .Dataset ):
170154 yield from dataset2list (value )
171155 return
172156
173- # handle scalar
174- if value .ndim == 0 :
157+ # otherwise we have a DataArray
158+ if value .ndim == 0 : # handle scalar
175159 if not np .isnan (value .item ()) and value .item () is not None :
176160 yield (value .item (),)
177161 return
@@ -184,15 +168,15 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
184168 for i , val in enumerate (values ):
185169 if has_spatial_dims :
186170 cellid = tuple (idx [i ] + 1 for idx in indices )
187- result = cellid + (val ,)
171+ rec = cellid + (val ,)
188172 else :
189- result = (val ,)
190- yield result
173+ rec = (val ,)
174+ yield rec
191175
192176
193177def dataset2list (value : xr .Dataset ):
194178 """
195- Yield record tuples from an xarray Dataset. For regular/tabular list-based format.
179+ Yield record tuples from an ` xarray. Dataset` . For regular/tabular list-based format.
196180
197181 Yields
198182 ------
@@ -202,72 +186,36 @@ def dataset2list(value: xr.Dataset):
202186 if value is None or not any (value .data_vars ):
203187 return
204188
205- # handle scalar
206- first_arr = next (iter (value .data_vars .values ()))
207- if first_arr .ndim == 0 :
208- field_vals = []
209- for field_name in value .data_vars .keys ():
210- field_val = value [field_name ]
211- if hasattr (field_val , "item" ):
212- field_vals .append (field_val .item ())
213- else :
214- field_vals .append (field_val )
215- yield tuple (field_vals )
189+ first = next (iter (value .data_vars .values ()))
190+ if first .ndim == 0 : # handle scalar
191+ vals = []
192+ for name in value .data_vars .keys ():
193+ val = value [name ]
194+ val = val .item () if val .shape == () else val
195+ vals .append (val )
196+ yield tuple (vals )
216197 return
217198
218- # build mask
219199 combined_mask : Any = None
220- for field_name , arr in value .data_vars .items ():
221- mask = nonempty (arr )
200+ for name , first in value .data_vars .items ():
201+ mask = nonempty (first )
222202 combined_mask = mask if combined_mask is None else combined_mask | mask
223203 if combined_mask is None or not np .any (combined_mask ):
224204 return
225205
226- spatial_dims = [d for d in first_arr .dims if d in ("nlay" , "nrow" , "ncol" , "nodes" )]
206+ spatial_dims = [d for d in first .dims if d in ("nlay" , "nrow" , "ncol" , "nodes" )]
227207 has_spatial_dims = len (spatial_dims ) > 0
228208 indices = np .where (combined_mask )
229209 for i in range (len (indices [0 ])):
230- field_vals = []
231- for field_name in value .data_vars .keys ():
232- field_val = value [field_name ][tuple (idx [i ] for idx in indices )]
233- if hasattr (field_val , "item" ):
234- field_vals .append (field_val .item ())
210+ vals = []
211+ for name in value .data_vars .keys ():
212+ val = value [name ][tuple (idx [i ] for idx in indices )]
213+ if hasattr (val , "item" ):
214+ vals .append (val .item ())
235215 else :
236- field_vals .append (field_val )
216+ vals .append (val )
237217 if has_spatial_dims :
238218 cellid = tuple (idx [i ] + 1 for idx in indices )
239- yield cellid + tuple (field_vals )
219+ yield cellid + tuple (vals )
240220 else :
241- yield tuple (field_vals )
242-
243-
244- def data2keystring (value : dict | xr .Dataset ):
245- """
246- Yield record tuples from a dict or dataset. For irregular list-based format, i.e. keystrings.
247-
248- Yields
249- ------
250- tuple
251- Tuples of (field_name, value) for use with record macro
252- """
253- if isinstance (value , dict ):
254- if not value :
255- return
256- for field_name , field_val in value .items ():
257- yield (field_name .upper (), field_val )
258- elif isinstance (value , xr .Dataset ):
259- if value is None or not any (value .data_vars ):
260- return
261-
262- for field_name in value .data_vars .keys ():
263- name = (
264- field_name .replace ("_" , " " ).upper ()
265- if np .issubdtype (value .data_vars [field_name ].dtype , np .str_ )
266- else field_name .upper ()
267- )
268- field_val = value [field_name ]
269- if hasattr (field_val , "item" ):
270- val = field_val .item ()
271- else :
272- val = field_val
273- yield (name , val )
221+ yield tuple (vals )
0 commit comments