Skip to content

Commit 44442f9

Browse files
authored
Merge branch 'main' into datavalidator
2 parents ba72baf + 2b32ba7 commit 44442f9

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

httomolibgpu/recon/rotation.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""Modules for finding the axis of rotation for 180 and 360 degrees scans"""
2222

2323
import numpy as np
24+
from numpy.polynomial import Polynomial
2425
from httomolibgpu import cupywrapper
2526

2627
cp = cupywrapper.cp
@@ -718,9 +719,23 @@ def _calculate_curvature(list_metric):
718719

719720
# work mostly on CPU here - we have very small arrays here
720721
list1 = cp.asnumpy(list_metric[min_pos - radi : min_pos + radi + 1])
721-
afact1 = np.polyfit(np.arange(0, 2 * radi + 1), list1, 2)[0]
722+
if not all(map(np.isfinite, list1)):
723+
raise ValueError(
724+
"The list of metrics (list1) contains nan's or infs. Check your input data"
725+
)
726+
727+
series1 = Polynomial.fit(np.arange(0, 2 * radi + 1), list1, deg=2)
728+
afact1 = series1.convert().coef[-1]
729+
722730
list2 = cp.asnumpy(list_metric[min_pos - 1 : min_pos + 2])
723-
(afact2, bfact2, _) = np.polyfit(np.arange(min_pos - 1, min_pos + 2), list2, 2)
731+
if not all(map(np.isfinite, list2)):
732+
raise ValueError(
733+
"The list of metrics (list2) contains nan's or infs. Check your input data"
734+
)
735+
736+
series2 = Polynomial.fit(np.arange(min_pos - 1, min_pos + 2), list2, deg=2)
737+
afact2 = series2.convert().coef[-1]
738+
bfact2 = series2.convert().coef[-1 - 1]
724739

725740
curvature = np.abs(afact1)
726741
if afact2 != 0.0:

tests/test_recon/test_rotation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ def test_find_center_360_1D_raises(data):
112112
find_center_360(cp.ones(10))
113113

114114

115+
def test_find_center_360_NaN_infs_raises(data, flats, darks):
116+
#: find_center_360 raises if the data with NaNs or Infs given
117+
data = data.astype(cp.float32)
118+
data[:] = cp.inf
119+
with pytest.raises(ValueError):
120+
find_center_360(data)
121+
data[:] = cp.nan
122+
with pytest.raises(ValueError):
123+
find_center_360(data)
124+
125+
115126
@pytest.mark.parametrize("norm", [False, True], ids=["no_normalise", "normalise"])
116127
@pytest.mark.parametrize("overlap", [False, True], ids=["no_overlap", "overlap"])
117128
@pytest.mark.parametrize("denoise", [False, True], ids=["no_denoise", "denoise"])

0 commit comments

Comments
 (0)