Skip to content

Commit 5423882

Browse files
committed
add tests
1 parent 99a1d3a commit 5423882

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed

test/test_converter_structure.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
"""
2+
Tests for flopy4.mf6.converter.structure module.
3+
4+
Integration tests for the refactored structure_array function with various input formats
5+
using real flopy4 components.
6+
"""
7+
8+
import numpy as np
9+
import sparse
10+
import xarray as xr
11+
12+
from flopy4.mf6.converter.structure import (
13+
_detect_grid_reshape,
14+
_fill_forward_time,
15+
_reshape_grid,
16+
_to_xarray,
17+
_validate_duck_array,
18+
)
19+
from flopy4.mf6.gwf.chd import Chd
20+
from flopy4.mf6.gwf.dis import Dis
21+
from flopy4.mf6.gwf.ic import Ic
22+
from flopy4.mf6.gwf.npf import Npf
23+
from flopy4.mf6.gwf.rch import Rch
24+
25+
26+
class TestHelperFunctions:
27+
"""Test helper functions that don't require full xattree setup."""
28+
29+
def test_detect_grid_reshape_structured_to_flat_3d(self):
30+
"""Test detection of (nlay, nrow, ncol) -> (nodes,) reshape."""
31+
value_shape = (2, 10, 10)
32+
expected_dims = ["nodes"]
33+
dim_dict = {"nlay": 2, "nrow": 10, "ncol": 10, "nodes": 200}
34+
35+
needs_reshape, target_shape = _detect_grid_reshape(value_shape, expected_dims, dim_dict)
36+
37+
assert needs_reshape is True
38+
assert target_shape == (200,)
39+
40+
def test_detect_grid_reshape_structured_to_flat_4d(self):
41+
"""Test detection of (nper, nlay, nrow, ncol) -> (nper, nodes) reshape."""
42+
value_shape = (3, 2, 10, 10)
43+
expected_dims = ["nper", "nodes"]
44+
dim_dict = {"nper": 3, "nlay": 2, "nrow": 10, "ncol": 10, "nodes": 200}
45+
46+
needs_reshape, target_shape = _detect_grid_reshape(value_shape, expected_dims, dim_dict)
47+
48+
assert needs_reshape is True
49+
assert target_shape == (3, 200)
50+
51+
def test_detect_grid_reshape_no_reshape_needed(self):
52+
"""Test when no reshape is needed."""
53+
value_shape = (100,)
54+
expected_dims = ["nodes"]
55+
dim_dict = {"nodes": 100}
56+
57+
needs_reshape, target_shape = _detect_grid_reshape(value_shape, expected_dims, dim_dict)
58+
59+
assert needs_reshape is False
60+
assert target_shape is None
61+
62+
def test_reshape_grid_numpy_array(self):
63+
"""Test reshaping numpy array."""
64+
data = np.ones((2, 10, 10))
65+
target_shape = (200,)
66+
67+
result = _reshape_grid(data, target_shape)
68+
69+
assert isinstance(result, np.ndarray)
70+
assert result.shape == (200,)
71+
assert np.all(result == 1.0)
72+
73+
def test_reshape_grid_xarray(self):
74+
"""Test reshaping xarray DataArray."""
75+
data = xr.DataArray(np.ones((2, 10, 10)), dims=["nlay", "nrow", "ncol"])
76+
target_shape = (200,)
77+
target_dims = ["nodes"]
78+
79+
result = _reshape_grid(data, target_shape, ["nlay", "nrow", "ncol"], target_dims)
80+
81+
assert isinstance(result, xr.DataArray)
82+
assert result.shape == (200,)
83+
assert result.dims == ("nodes",)
84+
85+
def test_validate_duck_array_numpy_correct_shape(self):
86+
"""Test validating numpy array with correct shape."""
87+
value = np.ones((3, 100))
88+
expected_dims = ["nper", "nodes"]
89+
expected_shape = (3, 100)
90+
dim_dict = {"nper": 3, "nodes": 100}
91+
92+
result = _validate_duck_array(value, expected_dims, expected_shape, dim_dict)
93+
94+
assert np.array_equal(result, value)
95+
96+
def test_validate_duck_array_xarray_correct_dims(self):
97+
"""Test validating xarray with correct dimensions."""
98+
value = xr.DataArray(np.ones((3, 100)), dims=["nper", "nodes"])
99+
expected_dims = ["nper", "nodes"]
100+
expected_shape = (3, 100)
101+
dim_dict = {"nper": 3, "nodes": 100}
102+
103+
result = _validate_duck_array(value, expected_dims, expected_shape, dim_dict)
104+
105+
assert isinstance(result, xr.DataArray)
106+
assert result.dims == ("nper", "nodes")
107+
108+
def test_fill_forward_time_numpy(self):
109+
"""Test adding nper dimension to numpy array."""
110+
data = np.ones((100,))
111+
dims = ["nper", "nodes"]
112+
nper = 3
113+
114+
result = _fill_forward_time(data, dims, nper)
115+
116+
assert result.shape == (3, 100)
117+
assert np.all(result == 1.0)
118+
119+
def test_fill_forward_time_xarray(self):
120+
"""Test adding nper dimension to xarray."""
121+
data = xr.DataArray(np.ones((100,)), dims=["nodes"])
122+
dims = ["nper", "nodes"]
123+
nper = 3
124+
125+
result = _fill_forward_time(data, dims, nper)
126+
127+
assert isinstance(result, xr.DataArray)
128+
assert result.shape == (3, 100)
129+
assert result.dims == ("nper", "nodes")
130+
131+
def test_to_xarray_numpy_array(self):
132+
"""Test wrapping numpy array in xarray."""
133+
data = np.ones((3, 100))
134+
dims = ["nper", "nodes"]
135+
coords = {"nper": np.arange(3), "nodes": np.arange(100)}
136+
attrs = {"units": "m"}
137+
138+
result = _to_xarray(data, dims, coords, attrs)
139+
140+
assert isinstance(result, xr.DataArray)
141+
assert result.dims == ("nper", "nodes")
142+
assert "nper" in result.coords
143+
assert result.attrs["units"] == "m"
144+
145+
146+
class TestDisComponent:
147+
"""Test structure_array with Dis component (array dims)."""
148+
149+
def test_dis_with_scalar_delr(self):
150+
"""Test Dis with scalar delr (broadcast to ncol)."""
151+
dis = Dis(nlay=1, nrow=10, ncol=10, delr=1.0, delc=1.0)
152+
153+
assert hasattr(dis, "delr")
154+
# Can be numpy or xarray depending on component configuration
155+
assert isinstance(dis.delr, (np.ndarray, xr.DataArray))
156+
if isinstance(dis.delr, xr.DataArray):
157+
assert dis.delr.shape == (10,)
158+
assert np.all(dis.delr.values == 1.0)
159+
else:
160+
assert dis.delr.shape == (10,)
161+
assert np.all(dis.delr == 1.0)
162+
163+
def test_dis_with_list_delr(self):
164+
"""Test Dis with list delr."""
165+
dis = Dis(nlay=1, nrow=10, ncol=10, delr=[1.0] * 10, delc=[2.0] * 10)
166+
167+
assert dis.delr.shape == (10,)
168+
assert np.all(dis.delr == 1.0)
169+
assert dis.delc.shape == (10,)
170+
assert np.all(dis.delc == 2.0)
171+
172+
def test_dis_with_numpy_array(self):
173+
"""Test Dis with numpy array input."""
174+
delr_array = np.linspace(1.0, 2.0, 10)
175+
dis = Dis(nlay=1, nrow=10, ncol=10, delr=delr_array, delc=1.0)
176+
177+
assert dis.delr.shape == (10,)
178+
assert np.allclose(dis.delr, delr_array)
179+
180+
181+
class TestIcComponent:
182+
"""Test structure_array with Ic component (initial conditions)."""
183+
184+
def test_ic_with_scalar_strt(self):
185+
"""Test IC with scalar starting head (broadcast to all nodes)."""
186+
ic = Ic(dims={"nlay": 1, "nrow": 10, "ncol": 10, "nodes": 100}, strt=100.0)
187+
188+
assert hasattr(ic, "strt")
189+
assert isinstance(ic.strt, (np.ndarray, xr.DataArray))
190+
if isinstance(ic.strt, xr.DataArray):
191+
assert ic.strt.shape == (100,)
192+
assert np.all(ic.strt.values == 100.0)
193+
else:
194+
assert ic.strt.shape == (100,)
195+
assert np.all(ic.strt == 100.0)
196+
197+
def test_ic_with_numpy_array(self):
198+
"""Test IC with numpy array."""
199+
strt_array = np.ones((100,)) * 50.0
200+
ic = Ic(dims={"nodes": 100}, strt=strt_array)
201+
202+
assert ic.strt.shape == (100,)
203+
assert np.all(ic.strt == 50.0)
204+
205+
def test_ic_with_structured_array(self):
206+
"""Test IC with structured grid array (should reshape to flat)."""
207+
# This would require grid reshaping functionality
208+
strt_3d = np.ones((1, 10, 10)) * 100.0
209+
ic = Ic(dims={"nlay": 1, "nrow": 10, "ncol": 10, "nodes": 100}, strt=strt_3d)
210+
211+
# Should be reshaped to flat nodes
212+
assert ic.strt.shape == (100,)
213+
assert np.all(ic.strt == 100.0)
214+
215+
216+
class TestNpfComponent:
217+
"""Test structure_array with Npf component."""
218+
219+
def test_npf_with_scalar_k(self):
220+
"""Test NPF with scalar hydraulic conductivity."""
221+
npf = Npf(dims={"nodes": 100}, k=1.0)
222+
223+
assert hasattr(npf, "k")
224+
assert isinstance(npf.k, (np.ndarray, xr.DataArray))
225+
if isinstance(npf.k, xr.DataArray):
226+
assert npf.k.shape == (100,)
227+
assert np.all(npf.k.values == 1.0)
228+
else:
229+
assert npf.k.shape == (100,)
230+
assert np.all(npf.k == 1.0)
231+
232+
def test_npf_with_layered_k(self):
233+
"""Test NPF with layered k values."""
234+
k_3d = np.ones((2, 10, 10))
235+
k_3d[0] = 10.0
236+
k_3d[1] = 1.0
237+
238+
npf = Npf(dims={"nlay": 2, "nrow": 10, "ncol": 10, "nodes": 200}, k=k_3d)
239+
240+
assert npf.k.shape == (200,)
241+
# First layer (nodes 0-99) should be 10.0
242+
assert np.all(npf.k[:100] == 10.0)
243+
# Second layer (nodes 100-199) should be 1.0
244+
assert np.all(npf.k[100:] == 1.0)
245+
246+
247+
class TestChdComponent:
248+
"""Test structure_array with Chd component (stress period data)."""
249+
250+
def test_chd_with_dict_format(self):
251+
"""Test CHD with dict format and cellid: value."""
252+
chd = Chd(
253+
dims={"nlay": 1, "nrow": 10, "ncol": 10, "nper": 3, "nodes": 100},
254+
head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
255+
)
256+
257+
assert hasattr(chd, "head")
258+
assert chd.head.shape == (3, 100)
259+
# SP 0 should have the values
260+
assert chd.head[0, 0] == 1.0
261+
assert chd.head[0, 99] == 0.0
262+
# SP 1 and 2 should fill forward from SP 0
263+
assert chd.head[1, 0] == 1.0
264+
assert chd.head[2, 99] == 0.0
265+
266+
def test_chd_with_star_key(self):
267+
"""Test CHD with '*' key for all stress periods."""
268+
chd = Chd(
269+
dims={"nlay": 1, "nrow": 10, "ncol": 10, "nper": 3, "nodes": 100},
270+
head={"*": {(0, 0, 0): 5.0}},
271+
)
272+
273+
# '*' should map to period 0 and fill forward
274+
assert chd.head[0, 0] == 5.0
275+
assert chd.head[1, 0] == 5.0
276+
assert chd.head[2, 0] == 5.0
277+
278+
def test_chd_with_fill_forward(self):
279+
"""Test CHD with fill-forward behavior."""
280+
chd = Chd(
281+
dims={"nlay": 1, "nrow": 10, "ncol": 10, "nper": 10, "nodes": 100},
282+
head={0: {(0, 0, 0): 1.0}, 5: {(0, 0, 0): 2.0}},
283+
)
284+
285+
# SP 0-4 should have first value
286+
assert chd.head[0, 0] == 1.0
287+
assert chd.head[4, 0] == 1.0
288+
289+
# SP 5+ should have second value
290+
assert chd.head[5, 0] == 2.0
291+
assert chd.head[9, 0] == 2.0
292+
293+
294+
class TestRchComponent:
295+
"""Test structure_array with Rch component (recharge)."""
296+
297+
def test_rch_with_scalar_dict(self):
298+
"""Test RCH with scalar values per stress period."""
299+
rch = Rch(
300+
dims={"nlay": 1, "nrow": 10, "ncol": 10, "nper": 3, "nodes": 100},
301+
recharge={0: 0.004, 1: 0.002},
302+
)
303+
304+
assert hasattr(rch, "recharge")
305+
# Should broadcast scalar to all nodes
306+
assert rch.recharge.shape == (3, 100)
307+
assert np.all(rch.recharge[0] == 0.004)
308+
assert np.all(rch.recharge[1] == 0.002)
309+
# SP 2 should fill forward from SP 1
310+
assert np.all(rch.recharge[2] == 0.002)
311+
312+
313+
class TestSparseArrays:
314+
"""Test sparse array creation for large arrays."""
315+
316+
def test_sparse_array_creation(self):
317+
"""Test that large sparse arrays use COO format."""
318+
# Create a CHD with very large grid (exceeds threshold)
319+
from flopy4.mf6.config import SPARSE_THRESHOLD
320+
321+
nper = 10
322+
nodes = 100000 # Large grid
323+
total_size = nper * nodes
324+
325+
if total_size > SPARSE_THRESHOLD:
326+
chd = Chd(
327+
dims={"nlay": 1, "nrow": 1000, "ncol": 100, "nper": nper, "nodes": nodes},
328+
head={0: {(0, 0, 0): 1.0, (0, 999, 99): 0.0}},
329+
)
330+
331+
# Should create sparse array (possibly wrapped in xarray)
332+
if isinstance(chd.head, xr.DataArray):
333+
# If wrapped in xarray, check the underlying data
334+
assert isinstance(chd.head.data, sparse.COO)
335+
assert chd.head.shape == (nper, nodes)
336+
else:
337+
assert isinstance(chd.head, sparse.COO)
338+
assert chd.head.shape == (nper, nodes)
339+
340+
341+
class TestXarrayOutput:
342+
"""Test xarray output functionality."""
343+
344+
def test_xarray_output_disabled_by_default(self):
345+
"""Test that xarray output is disabled by default for backward compatibility."""
346+
ic = Ic(dims={"nodes": 100}, strt=100.0)
347+
348+
# Default is return_xarray=False, so should get numpy
349+
# (this is set in the field converter, not directly testable here)
350+
assert isinstance(ic.strt, (np.ndarray, sparse.COO)) or isinstance(ic.strt, xr.DataArray)
351+
352+
353+
class TestEdgeCases:
354+
"""Test edge cases and special scenarios."""
355+
356+
def test_empty_dict_creates_default_array(self):
357+
"""Test that empty dict creates array with default values."""
358+
ic = Ic(dims={"nodes": 100}, strt={})
359+
360+
# Should create array with defaults
361+
assert hasattr(ic, "strt")
362+
assert ic.strt.shape == (100,)
363+
364+
def test_mixed_dict_value_types(self):
365+
"""Test dict with mixed value types (scalar, array)."""
366+
chd = Chd(
367+
dims={"nlay": 1, "nrow": 10, "ncol": 10, "nper": 10, "nodes": 100},
368+
head={
369+
0: {(0, 0, 0): 1.0}, # Dict with cellid
370+
5: {(0, 0, 0): 2.0, (0, 9, 9): 0.5}, # Multiple cellids
371+
},
372+
)
373+
374+
assert chd.head[0, 0] == 1.0
375+
assert chd.head[5, 0] == 2.0
376+
assert chd.head[5, 99] == 0.5

0 commit comments

Comments
 (0)