Skip to content

Commit 8c0ab2a

Browse files
committed
be more tolerant
1 parent 17fe33e commit 8c0ab2a

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/test_proc_ops.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_scale_linear(tid: MemberId):
3030
op(sample)
3131

3232
expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c"))
33-
xr.testing.assert_allclose(expected, sample.members[tid].data)
33+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
3434

3535

3636
def test_scale_linear_no_channel(tid: MemberId):
@@ -42,7 +42,7 @@ def test_scale_linear_no_channel(tid: MemberId):
4242
op(sample)
4343

4444
expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y"))
45-
xr.testing.assert_allclose(expected, sample.members[tid].data)
45+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
4646

4747

4848
T = TypeVar("T")
@@ -75,7 +75,7 @@ def test_zero_mean_unit_variance(tid: MemberId):
7575
),
7676
dims=("x", "y"),
7777
)
78-
xr.testing.assert_allclose(expected, sample.members[tid].data)
78+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
7979

8080

8181
def test_zero_mean_unit_variance_fixed(tid: MemberId):
@@ -102,7 +102,7 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId):
102102
)
103103
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
104104
op(sample)
105-
xr.testing.assert_allclose(expected, sample.members[tid].data)
105+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
106106

107107

108108
def test_zero_mean_unit_across_axes(tid: MemberId):
@@ -123,7 +123,7 @@ def test_zero_mean_unit_across_axes(tid: MemberId):
123123
[(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c"
124124
)
125125
op(sample)
126-
xr.testing.assert_allclose(expected, sample.members[tid].data)
126+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
127127

128128

129129
def test_zero_mean_unit_variance_fixed2(tid: MemberId):
@@ -139,7 +139,7 @@ def test_zero_mean_unit_variance_fixed2(tid: MemberId):
139139
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
140140
expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y"))
141141
op(sample)
142-
xr.testing.assert_allclose(expected, sample.members[tid].data)
142+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
143143

144144

145145
def test_binarize(tid: MemberId):
@@ -223,7 +223,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId):
223223
)
224224

225225
op(sample)
226-
xr.testing.assert_allclose(expected, sample.members[tid].data)
226+
xr.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
227227

228228

229229
@pytest.mark.parametrize(
@@ -255,7 +255,7 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]):
255255
)
256256
sample.stat = compute_measures(op.required_measures, [sample])
257257
op(sample)
258-
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
258+
xr.testing.assert_allclose(ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7)
259259

260260

261261
@pytest.mark.parametrize(
@@ -290,11 +290,15 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str])
290290

291291
if axes is not None and AxisId("c") not in axes:
292292
# mean,std per channel should match exactly
293-
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
293+
xr.testing.assert_allclose(
294+
ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7
295+
)
294296
else:
295297
# mean,std across channels should not match
296298
with pytest.raises(AssertionError):
297-
xr.testing.assert_allclose(ref_data, sample.members[tid].data)
299+
xr.testing.assert_allclose(
300+
ref_data, sample.members[tid].data, rtol=1e-6, atol=1e-7
301+
)
298302

299303

300304
def test_scale_range(tid: MemberId):
@@ -313,7 +317,7 @@ def test_scale_range(tid: MemberId):
313317

314318
op(sample)
315319
# NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct
316-
np.testing.assert_allclose(expected, sample.members[tid].data)
320+
np.testing.assert_allclose(expected, sample.members[tid].data, rtol=1e-6, atol=1e-7)
317321

318322

319323
def test_scale_range_axes(tid: MemberId):
@@ -363,4 +367,4 @@ def test_sigmoid(tid: MemberId):
363367
sigmoid(sample)
364368

365369
exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes)
366-
xr.testing.assert_allclose(exp, sample.members[tid].data)
370+
xr.testing.assert_allclose(exp, sample.members[tid].data, rtol=1e-6, atol=1e-7)

0 commit comments

Comments
 (0)