@@ -27,7 +27,13 @@ def pts_dg(coords_dg):
2727
2828 return pts
2929
30+ def apply_mask (arr , mask ):
31+ """Apply a mask to arr and return 0.0 where mask is True"""
3032
33+ arr_masked = arr .copy ()
34+ arr_masked [mask ] = 0.0
35+
36+ return arr_masked
3137
3238def test_flowmap (pts_dg , fm_data ):
3339
@@ -46,31 +52,42 @@ def test_flowmap_n(pts_dg, fm_n_data):
4652 assert np .allclose (fm_n ,fm_n_data )
4753
4854
49- def test_flowmap_grid_2D (coords_dg , fm_data ):
55+ def test_flowmap_grid_2D (coords_dg , fm_data , mask_dg ):
5056
5157 x ,y = coords_dg
5258 fm = flowmap_grid_2D (funcptr , t0 , T , x , y , params ).astype (np .float32 )
5359
54- assert np .allclose (fm ,fm_data )
60+ assert np .allclose (fm , fm_data )
5561
62+ fm_masked = flowmap_grid_2D (funcptr , t0 , T , x , y , params , mask = mask_dg ).astype (np .float32 )
5663
57- def test_flowmap_aux_grid_2D (coords_dg , fm_aux_data ):
64+ assert np .allclose (fm_masked , apply_mask (fm_data , mask_dg ))
65+
66+ def test_flowmap_aux_grid_2D (coords_dg , fm_aux_data , mask_dg ):
5867
5968 x ,y = coords_dg
6069 fm_aux = flowmap_aux_grid_2D (funcptr , t0 , T , x , y , params ).astype (np .float32 )
6170
6271 assert np .allclose (fm_aux ,fm_aux_data )
6372
73+ fm_aux_masked = flowmap_aux_grid_2D (funcptr , t0 , T , x , y , params , mask = mask_dg ).astype (np .float32 )
74+
75+ assert np .allclose (fm_aux_masked , apply_mask (fm_aux_data , mask_dg ))
6476
65- def test_flowmap_n_grid_2D (coords_dg , fm_n_data ):
77+
78+ def test_flowmap_n_grid_2D (coords_dg , fm_n_data , mask_dg ):
6679
6780 x ,y = coords_dg
6881 t_eval_expected = params [0 ]* np .linspace (t0 ,t0 + T ,n )
69- fm_n , t_eval = flowmap_n_grid_2D (funcptr , t0 , T , x , y , params , n = n )
82+ fm_n , t_eval = flowmap_n_grid_2D (funcptr , t0 , T , x , y , params , n = n )
7083
7184 assert np .allclose (t_eval_expected ,t_eval )
7285 assert np .allclose (fm_n .astype (np .float32 ),fm_n_data )
7386
87+ fm_n_masked , _ = flowmap_n_grid_2D (funcptr , t0 , T , x , y , params , n = n , mask = mask_dg )
88+
89+ assert np .allclose (fm_n_masked .astype (np .float32 ), apply_mask (fm_n_data , mask_dg ))
90+
7491
7592def test_flowmap_composition_initial (coords_dg , fm_ci_data , fms_ci_data ):
7693
0 commit comments