@@ -32,13 +32,13 @@ def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str,
3232 if not isinstance (value , Context ):
3333 return {}
3434
35- blocks = {}
35+ blocks = {} # type: ignore
3636 xatspec = xattree .get_xatspec (type (value ))
3737
3838 for child_name , child_spec in xatspec .children .items ():
3939 if (child := getattr (value , child_name , None )) is None :
4040 continue
41- if (block_name := child_spec .metadata ["block" ]) not in blocks :
41+ if (block_name := child_spec .metadata ["block" ]) not in blocks : # type: ignore
4242 blocks [block_name ] = {}
4343 match child :
4444 case Component ():
@@ -80,55 +80,57 @@ def has_tdis_dims(value: xr.DataArray) -> bool:
8080 return "nper" in value .dims
8181
8282
83- def _hack_grid_dims (value , field_value ):
84- # terrible hack to convert flat nodes dimension to 3d structured dims.
85- # long term solution for this is to use a custom xarray index. filters
86- # should then have access to all dimensions needed.
87- dims_ = set (field_value .dims ).copy ()
88- parent = value .parent # type: ignore
89- if parent is None :
90- # TODO for standalone packages
91- return field_value
83+ def _hack_structured_grid_dims (value : xr .DataArray , structured_grid_dims : Mapping ):
84+ """
85+ Temporary hack to convert flat nodes dimension to 3d structured dims.
86+ long term solution for this is to use a custom xarray index. filters
87+ should then have access to all dimensions needed.
88+ """
9289
93- if "nper" in dims_ :
94- dims_ .remove ("nper" )
95- shape = (
96- field_value .sizes ["nper" ],
97- parent . dims ["nlay" ],
98- parent . dims ["nrow" ],
99- parent . dims ["ncol" ],
90+ if "nper" in ( old_dims := set ( value . dims ). copy ()) :
91+ old_dims .remove ("nper" )
92+ shape : tuple [ int , ...] = (
93+ value .sizes ["nper" ],
94+ structured_grid_dims ["nlay" ],
95+ structured_grid_dims ["nrow" ],
96+ structured_grid_dims ["ncol" ],
10097 )
101- dims = ("nper" , "nlay" , "nrow" , "ncol" )
98+ dims : tuple [ str , ...] = ("nper" , "nlay" , "nrow" , "ncol" )
10299 coords = {
103- "nper" : field_value .coords ["nper" ],
104- "nlay" : range (parent . dims ["nlay" ]),
105- "nrow" : range (parent . dims ["nrow" ]),
106- "ncol" : range (parent . dims ["ncol" ]),
100+ "nper" : value .coords ["nper" ],
101+ "nlay" : range (structured_grid_dims ["nlay" ]),
102+ "nrow" : range (structured_grid_dims ["nrow" ]),
103+ "ncol" : range (structured_grid_dims ["ncol" ]),
107104 }
108105 else :
109106 shape = (
110- parent . dims ["nlay" ],
111- parent . dims ["nrow" ],
112- parent . dims ["ncol" ],
107+ structured_grid_dims ["nlay" ],
108+ structured_grid_dims ["nrow" ],
109+ structured_grid_dims ["ncol" ],
113110 )
114111 dims = ("nlay" , "nrow" , "ncol" )
115112 coords = {
116- "nlay" : range (parent . dims ["nlay" ]),
117- "nrow" : range (parent . dims ["nrow" ]),
118- "ncol" : range (parent . dims ["ncol" ]),
113+ "nlay" : range (structured_grid_dims ["nlay" ]),
114+ "nrow" : range (structured_grid_dims ["nrow" ]),
115+ "ncol" : range (structured_grid_dims ["ncol" ]),
119116 }
120117
121- if dims_ == {"nodes" }:
122- field_value = xr .DataArray (
123- field_value .data .reshape (shape ),
118+ if old_dims == {"nodes" }:
119+ value = xr .DataArray (
120+ value .data .reshape (shape ),
124121 dims = dims ,
125122 coords = coords ,
126123 )
127124
128- return field_value
125+ return value
129126
130127
131- def unstructure_field (name : str , value : Any ) -> tuple [str , Any ]:
128+ def unstructure_field (
129+ name : str ,
130+ value : Any ,
131+ # TODO: temporary, remove not needed
132+ structured_grid_dims : Mapping | None ,
133+ ) -> tuple [str , Any ]:
132134 """
133135 Convert:
134136
@@ -166,19 +168,34 @@ def unstructure_field(name: str, value: Any) -> tuple[str, Any]:
166168 return name , value .isoformat ()
167169 case xr .DataArray ():
168170 if name == "auxiliary" :
169- value = tuple (value .values .tolist ())
171+ return name , tuple (value .values .tolist ())
170172 if has_grid_dims (value ):
171- value = _hack_grid_dims (value , value )
173+ if structured_grid_dims is None :
174+ raise ValueError ("Need structured grid dimension sizes" )
175+ value = _hack_structured_grid_dims (value , structured_grid_dims = structured_grid_dims )
172176 if has_tdis_dims (value ):
173177 value = {kper : value .isel (nper = kper ) for kper in range (value .sizes ["nper" ])}
174178 return name , value
175179 case _:
176180 return name , value
177181
178182
179- def unstructure_block (block : dict [str , Any ]) -> dict [str , Any ]:
183+ def unstructure_block (
184+ block : dict [str , Any ],
185+ # TODO: temporary, remove not needed
186+ structured_grid_dims : Mapping | None ,
187+ ) -> dict [str , Any ]:
180188 """Unstructure a block of data, converting fields to a suitable format."""
181- return dict ([unstructure_field (block .get (field_name , None )) for field_name in block .keys ()])
189+ return dict (
190+ [
191+ unstructure_field (
192+ name = field_name ,
193+ value = block .get (field_name , None ),
194+ structured_grid_dims = structured_grid_dims ,
195+ )
196+ for field_name in block .keys ()
197+ ]
198+ )
182199
183200
184201def _hack_field_metadata (
@@ -195,8 +212,8 @@ def _hack_field_metadata(
195212
196213def segment_period_data (block : dict [str , Any ], cls : type [Component ]) -> dict [str , dict [str , Any ]]:
197214 """Partition period data by stress period"""
198- arrays = {}
199- blocks = {}
215+ arrays = {} # type: ignore
216+ blocks = {} # type: ignore
200217 period = PERIOD .upper ()
201218
202219 for arr_name , periods in block .items ():
@@ -220,14 +237,26 @@ def unstructure_component(value: Component) -> dict[str, Any]:
220237 data = value .to_dict (blocks = True )
221238 blocks : dict [str , dict [str , Any ]] = {}
222239 blocks .update (binding_blocks := get_binding_blocks (value ))
240+
241+ # temporary hack! TODO remove once we have a structured grid index
242+ if "nlay" in value .data .dims : # type: ignore
243+ structured_grid_dims = value .data .dims # type: ignore
244+ elif value .data .parent is not None and "nlay" in value .data .parent .dims : # type: ignore
245+ structured_grid_dims = value .data .parent .dims # type: ignore
246+ else :
247+ structured_grid_dims = None
248+
223249 blocks .update (
224250 {
225- block_name : unstructure_block (data [block_name ])
251+ block_name : unstructure_block (
252+ data [block_name ], structured_grid_dims = structured_grid_dims
253+ )
226254 for block_name in dfn .blocks .keys ()
227255 if block_name not in binding_blocks
228256 }
229257 )
230258 if period_block := blocks .pop (PERIOD , None ):
259+ period_block = {k : v for k , v in period_block .items () if v is not None }
231260 blocks .update (segment_period_data (period_block , cls ))
232261
233262 # total temporary hack! manually set solutiongroup 1.
0 commit comments