Skip to content

Commit 15a15b9

Browse files
committed
cleanup
1 parent d1300c2 commit 15a15b9

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed

flopy4/mf6/codec/writer/filters.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,15 @@ def dataset2list(value: xr.Dataset):
201201
if value is None or not any(value.data_vars):
202202
return
203203

204-
first = next(iter(value.data_vars.values()))
205-
is_union = first.dtype.type is np.str_
204+
# special case OC for now.
205+
is_oc = all(
206+
str(v.name).startswith("save_") or str(v.name).startswith("print_")
207+
for v in value.data_vars.values()
208+
)
206209

207-
if first.ndim == 0: # handle scalar
208-
if is_union:
210+
# handle scalar
211+
if (first := next(iter(value.data_vars.values()))).ndim == 0:
212+
if is_oc:
209213
for name in value.data_vars.keys():
210214
val = value[name]
211215
val = val.item() if val.shape == () else val
@@ -230,7 +234,7 @@ def dataset2list(value: xr.Dataset):
230234
has_spatial_dims = len(spatial_dims) > 0
231235
indices = np.where(combined_mask)
232236
for i in range(len(indices[0])):
233-
if is_union:
237+
if is_oc:
234238
for name in value.data_vars.keys():
235239
val = value[name][tuple(idx[i] for idx in indices)]
236240
val = val.item() if val.shape == () else val

flopy4/mf6/converter.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
110110
blocks[block_name] = {}
111111
period_data = {}
112112
period_blocks = {} # type: ignore
113+
period_block_name = None
113114

114115
for field_name in block.keys():
115116
# Skip child components that have been processed as bindings
@@ -150,51 +151,43 @@ def unstructure_component(value: Component) -> dict[str, Any]:
150151
field_value,
151152
structured_grid_dims=value.parent.data.dims, # type: ignore
152153
)
153-
154+
if "period" in block_name:
155+
period_block_name = block_name
154156
period_data[field_name] = {
155157
kper: field_value.isel(nper=kper)
156158
for kper in range(field_value.sizes["nper"])
157159
}
158160
else:
159-
# TODO why not putting in block here but doing below? how does this even work
160-
if np.issubdtype(field_value.dtype, np.str_):
161-
period_data[field_name] = {
162-
kper: field_value[kper] for kper in range(field_value.sizes["nper"])
163-
}
164-
else:
165-
if block_name not in period_data:
166-
period_data[block_name] = {}
167-
period_data[block_name][field_name] = field_value # type: ignore
161+
blocks[block_name][field_name] = field_value
168162
else:
169163
if field_value is not None:
164+
# only include boolean fields (keywords) if true
170165
if isinstance(field_value, bool):
171166
if field_value:
172167
blocks[block_name][field_name] = field_value
173168
else:
174169
blocks[block_name][field_name] = field_value
175170

176-
if block_name in period_data and isinstance(period_data[block_name], dict):
177-
dataset = xr.Dataset(period_data[block_name])
178-
blocks[block_name] = {block_name: dataset}
179-
del period_data[block_name]
180-
181171
for arr_name, periods in period_data.items():
182172
for kper, arr in periods.items():
183173
if kper not in period_blocks:
184174
period_blocks[kper] = {}
185175
period_blocks[kper][arr_name] = arr
186176

187177
for kper, block in period_blocks.items():
188-
dataset = xr.Dataset(block)
189-
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
178+
assert isinstance(period_block_name, str)
179+
blocks[f"{period_block_name} {kper + 1}"] = {
180+
period_block_name: xr.Dataset(block, coords=block[arr_name].coords)
181+
}
190182

191-
# total temporary hack! manually set solutiongroup 1. still need to support multiple..
183+
# total temporary hack! manually set solutiongroup 1.
184+
# TODO still need to support multiple..
192185
if "solutiongroup" in blocks:
193186
sg = blocks["solutiongroup"]
194187
blocks["solutiongroup 1"] = sg
195188
del blocks["solutiongroup"]
196189

197-
return {name: block for name, block in blocks.items() if name != "period"}
190+
return {name: block for name, block in blocks.items() if name != period_block_name}
198191

199192

200193
def _make_converter() -> Converter:

test/test_codec.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from pprint import pprint
22

3-
import pytest
4-
53
from flopy4.mf6.codec import dumps, loads
64
from flopy4.mf6.converter import COMPONENT_CONVERTER
75

@@ -55,7 +53,6 @@ def test_dumps_ic():
5553
pprint(loaded)
5654

5755

58-
@pytest.mark.xfail(reason="TODO")
5956
def test_dumps_oc():
6057
from flopy4.mf6.gwf import Oc
6158

0 commit comments

Comments
 (0)