Skip to content

Commit 79e7eb3

Browse files
committed
update tests for masked integration
1 parent f11e267 commit 79e7eb3

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

tests/test_integration.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@ def pts_dg(coords_dg):
2727

2828
return pts
2929

30+
def apply_mask(arr, mask):
31+
"""Apply a mask to arr and return 0.0 where mask is True"""
3032

33+
arr_masked = arr.copy()
34+
arr_masked[mask] = 0.0
35+
36+
return arr_masked
3137

3238
def test_flowmap(pts_dg, fm_data):
3339

@@ -46,31 +52,42 @@ def test_flowmap_n(pts_dg, fm_n_data):
4652
assert np.allclose(fm_n,fm_n_data)
4753

4854

49-
def test_flowmap_grid_2D(coords_dg, fm_data):
55+
def test_flowmap_grid_2D(coords_dg, fm_data, mask_dg):
5056

5157
x,y = coords_dg
5258
fm = flowmap_grid_2D(funcptr, t0, T, x, y, params).astype(np.float32)
5359

54-
assert np.allclose(fm,fm_data)
60+
assert np.allclose(fm, fm_data)
5561

62+
fm_masked = flowmap_grid_2D(funcptr, t0, T, x, y, params, mask=mask_dg).astype(np.float32)
5663

57-
def test_flowmap_aux_grid_2D(coords_dg, fm_aux_data):
64+
assert np.allclose(fm_masked, apply_mask(fm_data, mask_dg))
65+
66+
def test_flowmap_aux_grid_2D(coords_dg, fm_aux_data, mask_dg):
5867

5968
x,y = coords_dg
6069
fm_aux = flowmap_aux_grid_2D(funcptr, t0, T, x, y, params).astype(np.float32)
6170

6271
assert np.allclose(fm_aux,fm_aux_data)
6372

73+
fm_aux_masked = flowmap_aux_grid_2D(funcptr, t0, T, x, y, params, mask=mask_dg).astype(np.float32)
74+
75+
assert np.allclose(fm_aux_masked, apply_mask(fm_aux_data, mask_dg))
6476

65-
def test_flowmap_n_grid_2D(coords_dg, fm_n_data):
77+
78+
def test_flowmap_n_grid_2D(coords_dg, fm_n_data, mask_dg):
6679

6780
x,y = coords_dg
6881
t_eval_expected = params[0]*np.linspace(t0,t0+T,n)
69-
fm_n, t_eval = flowmap_n_grid_2D(funcptr, t0, T, x, y, params, n = n)
82+
fm_n, t_eval = flowmap_n_grid_2D(funcptr, t0, T, x, y, params, n=n)
7083

7184
assert np.allclose(t_eval_expected,t_eval)
7285
assert np.allclose(fm_n.astype(np.float32),fm_n_data)
7386

87+
fm_n_masked, _ = flowmap_n_grid_2D(funcptr, t0, T, x, y, params, n=n, mask=mask_dg)
88+
89+
assert np.allclose(fm_n_masked.astype(np.float32), apply_mask(fm_n_data, mask_dg))
90+
7491

7592
def test_flowmap_composition_initial(coords_dg, fm_ci_data, fms_ci_data):
7693

0 commit comments

Comments
 (0)