Skip to content

Commit c4cd362

Browse files
authored
TST: interpolate: parametrize tests on griddata (scipy#21815)
1 parent 8db8672 commit c4cd362

File tree

1 file changed

+75
-66
lines changed

1 file changed

+75
-66
lines changed

scipy/interpolate/tests/test_ndgriddata.py

Lines changed: 75 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@
1515
"interpolator", [NearestNDInterpolator, LinearNDInterpolator,
1616
CloughTocher2DInterpolator]
1717
)
18+
parametrize_methods = pytest.mark.parametrize(
19+
'method',
20+
('nearest', 'linear', 'cubic'),
21+
)
22+
parametrize_rescale = pytest.mark.parametrize(
23+
'rescale',
24+
(True, False),
25+
)
26+
1827

1928
class TestGriddata:
2029
def test_fill_value(self):
@@ -27,75 +36,75 @@ def test_fill_value(self):
2736
yi = griddata(x, y, [(1,1), (1,2), (0,0)])
2837
xp_assert_equal(yi, [np.nan, np.nan, 1])
2938

30-
def test_alternative_call(self):
39+
@parametrize_methods
40+
@parametrize_rescale
41+
def test_alternative_call(self, method, rescale):
3142
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
3243
dtype=np.float64)
3344
y = (np.arange(x.shape[0], dtype=np.float64)[:,None]
3445
+ np.array([0,1])[None,:])
3546

36-
for method in ('nearest', 'linear', 'cubic'):
37-
for rescale in (True, False):
38-
msg = repr((method, rescale))
39-
yi = griddata((x[:,0], x[:,1]), y, (x[:,0], x[:,1]), method=method,
40-
rescale=rescale)
41-
xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
47+
msg = repr((method, rescale))
48+
yi = griddata((x[:,0], x[:,1]), y, (x[:,0], x[:,1]), method=method,
49+
rescale=rescale)
50+
xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
4251

43-
def test_multivalue_2d(self):
52+
@parametrize_methods
53+
@parametrize_rescale
54+
def test_multivalue_2d(self, method, rescale):
4455
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
4556
dtype=np.float64)
4657
y = (np.arange(x.shape[0], dtype=np.float64)[:,None]
4758
+ np.array([0,1])[None,:])
4859

49-
for method in ('nearest', 'linear', 'cubic'):
50-
for rescale in (True, False):
51-
msg = repr((method, rescale))
52-
yi = griddata(x, y, x, method=method, rescale=rescale)
53-
xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
60+
msg = repr((method, rescale))
61+
yi = griddata(x, y, x, method=method, rescale=rescale)
62+
xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
5463

55-
def test_multipoint_2d(self):
64+
@parametrize_methods
65+
@parametrize_rescale
66+
def test_multipoint_2d(self, method, rescale):
5667
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
5768
dtype=np.float64)
5869
y = np.arange(x.shape[0], dtype=np.float64)
5970

6071
xi = x[:,None,:] + np.array([0,0,0])[None,:,None]
6172

62-
for method in ('nearest', 'linear', 'cubic'):
63-
for rescale in (True, False):
64-
msg = repr((method, rescale))
65-
yi = griddata(x, y, xi, method=method, rescale=rescale)
73+
msg = repr((method, rescale))
74+
yi = griddata(x, y, xi, method=method, rescale=rescale)
6675

67-
assert yi.shape == (5, 3), msg
68-
xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
69-
atol=1e-14, err_msg=msg)
76+
assert yi.shape == (5, 3), msg
77+
xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
78+
atol=1e-14, err_msg=msg)
7079

71-
def test_complex_2d(self):
80+
@parametrize_methods
81+
@parametrize_rescale
82+
def test_complex_2d(self, method, rescale):
7283
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
7384
dtype=np.float64)
7485
y = np.arange(x.shape[0], dtype=np.float64)
7586
y = y - 2j*y[::-1]
7687

7788
xi = x[:,None,:] + np.array([0,0,0])[None,:,None]
7889

79-
for method in ('nearest', 'linear', 'cubic'):
80-
for rescale in (True, False):
81-
msg = repr((method, rescale))
82-
yi = griddata(x, y, xi, method=method, rescale=rescale)
90+
msg = repr((method, rescale))
91+
yi = griddata(x, y, xi, method=method, rescale=rescale)
8392

84-
assert yi.shape == (5, 3)
85-
xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
86-
atol=1e-14, err_msg=msg)
93+
assert yi.shape == (5, 3)
94+
xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
95+
atol=1e-14, err_msg=msg)
8796

88-
def test_1d(self):
97+
@parametrize_methods
98+
def test_1d(self, method):
8999
x = np.array([1, 2.5, 3, 4.5, 5, 6])
90100
y = np.array([1, 2, 0, 3.9, 2, 1])
91101

92-
for method in ('nearest', 'linear', 'cubic'):
93-
xp_assert_close(griddata(x, y, x, method=method), y,
94-
err_msg=method, atol=1e-14)
95-
xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
96-
err_msg=method, atol=1e-14)
97-
xp_assert_close(griddata((x,), y, (x,), method=method), y,
98-
err_msg=method, atol=1e-14)
102+
xp_assert_close(griddata(x, y, x, method=method), y,
103+
err_msg=method, atol=1e-14)
104+
xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
105+
err_msg=method, atol=1e-14)
106+
xp_assert_close(griddata((x,), y, (x,), method=method), y,
107+
err_msg=method, atol=1e-14)
99108

100109
def test_1d_borders(self):
101110
# Test for nearest neighbor case with xi outside
@@ -119,19 +128,20 @@ def test_1d_borders(self):
119128
err_msg=method,
120129
atol=1e-14)
121130

122-
def test_1d_unsorted(self):
131+
@parametrize_methods
132+
def test_1d_unsorted(self, method):
123133
x = np.array([2.5, 1, 4.5, 5, 6, 3])
124134
y = np.array([1, 2, 0, 3.9, 2, 1])
125135

126-
for method in ('nearest', 'linear', 'cubic'):
127-
xp_assert_close(griddata(x, y, x, method=method), y,
128-
err_msg=method, atol=1e-10)
129-
xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
130-
err_msg=method, atol=1e-10)
131-
xp_assert_close(griddata((x,), y, (x,), method=method), y,
132-
err_msg=method, atol=1e-10)
136+
xp_assert_close(griddata(x, y, x, method=method), y,
137+
err_msg=method, atol=1e-10)
138+
xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
139+
err_msg=method, atol=1e-10)
140+
xp_assert_close(griddata((x,), y, (x,), method=method), y,
141+
err_msg=method, atol=1e-10)
133142

134-
def test_square_rescale_manual(self):
143+
@parametrize_methods
144+
def test_square_rescale_manual(self, method):
135145
points = np.array([(0,0), (0,100), (10,100), (10,0), (1, 5)], dtype=np.float64)
136146
points_rescaled = np.array([(0,0), (0,1), (1,1), (1,0), (0.1, 0.05)],
137147
dtype=np.float64)
@@ -143,16 +153,16 @@ def test_square_rescale_manual(self):
143153
yy = yy.ravel()
144154
xi = np.array([xx, yy]).T.copy()
145155

146-
for method in ('nearest', 'linear', 'cubic'):
147-
msg = method
148-
zi = griddata(points_rescaled, values, xi/np.array([10, 100.]),
149-
method=method)
150-
zi_rescaled = griddata(points, values, xi, method=method,
151-
rescale=True)
152-
xp_assert_close(zi, zi_rescaled, err_msg=msg,
153-
atol=1e-12)
156+
msg = method
157+
zi = griddata(points_rescaled, values, xi/np.array([10, 100.]),
158+
method=method)
159+
zi_rescaled = griddata(points, values, xi, method=method,
160+
rescale=True)
161+
xp_assert_close(zi, zi_rescaled, err_msg=msg,
162+
atol=1e-12)
154163

155-
def test_xi_1d(self):
164+
@parametrize_methods
165+
def test_xi_1d(self, method):
156166
# Check that 1-D xi is interpreted as a coordinate
157167
x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
158168
dtype=np.float64)
@@ -161,17 +171,16 @@ def test_xi_1d(self):
161171

162172
xi = np.array([0.5, 0.5])
163173

164-
for method in ('nearest', 'linear', 'cubic'):
165-
p1 = griddata(x, y, xi, method=method)
166-
p2 = griddata(x, y, xi[None,:], method=method)
167-
xp_assert_close(p1, p2, err_msg=method)
168-
169-
xi1 = np.array([0.5])
170-
xi3 = np.array([0.5, 0.5, 0.5])
171-
assert_raises(ValueError, griddata, x, y, xi1,
172-
method=method)
173-
assert_raises(ValueError, griddata, x, y, xi3,
174-
method=method)
174+
p1 = griddata(x, y, xi, method=method)
175+
p2 = griddata(x, y, xi[None,:], method=method)
176+
xp_assert_close(p1, p2, err_msg=method)
177+
178+
xi1 = np.array([0.5])
179+
xi3 = np.array([0.5, 0.5, 0.5])
180+
assert_raises(ValueError, griddata, x, y, xi1,
181+
method=method)
182+
assert_raises(ValueError, griddata, x, y, xi3,
183+
method=method)
175184

176185

177186
class TestNearestNDInterpolator:

0 commit comments

Comments
 (0)