Skip to content

Commit 686dad9

Browse files
committed
raising exception before iterations stuck in indefinite loop, tests
1 parent b5250c5 commit 686dad9

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

httomolibgpu/recon/rotation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -713,10 +713,10 @@ def _calculate_curvature(list_metric):
713713

714714
# work mostly on CPU here - we have very small arrays here
715715
list1 = cp.asnumpy(list_metric[min_pos - radi : min_pos + radi + 1])
716-
list1[np.isnan(list1)] = list1[~np.isnan(list1)].mean()
717-
list1[np.isinf(list1)] = list1[~np.isinf(list1)].mean()
718-
719-
# afact1 = np.polyfit(np.arange(0, 2 * radi + 1), list1, 2)[0]
716+
if not all(map(np.isfinite, list1)):
717+
raise ValueError(
718+
"The list of metrics (list1) contains nan's or infs. Check your input data"
719+
)
720720

721721
series1 = Polynomial.fit(np.arange(0, 2 * radi + 1), list1, deg=2)
722722
afact1 = series1.convert().coef[-1]
@@ -725,7 +725,10 @@ def _calculate_curvature(list_metric):
725725
list2[np.isnan(list2)] = list2[~np.isnan(list2)].mean()
726726
list2[np.isinf(list2)] = list2[~np.isinf(list2)].mean()
727727

728-
# (afact2, bfact2, _) = np.polyfit(np.arange(min_pos - 1, min_pos + 2), list2, 2)
728+
if not all(map(np.isfinite, list2)):
729+
raise ValueError(
730+
"The list of metrics (list2) contains nan's or infs. Check your input data"
731+
)
729732

730733
series2 = Polynomial.fit(np.arange(min_pos - 1, min_pos + 2), list2, deg=2)
731734
afact2 = series2.convert().coef[-1]

tests/test_recon/test_rotation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ def test_find_center_360_1D_raises(data):
112112
find_center_360(cp.ones(10))
113113

114114

115+
def test_find_center_360_NaN_infs_raises(data, flats, darks):
116+
#: find_center_360 raises if the data with NaNs or Infs given
117+
data = data.astype(cp.float32)
118+
data[:] = cp.inf
119+
with pytest.raises(ValueError):
120+
find_center_360(data)
121+
data[:] = cp.nan
122+
with pytest.raises(ValueError):
123+
find_center_360(data)
124+
125+
115126
@pytest.mark.parametrize("norm", [False, True], ids=["no_normalise", "normalise"])
116127
@pytest.mark.parametrize("overlap", [False, True], ids=["no_overlap", "overlap"])
117128
@pytest.mark.parametrize("denoise", [False, True], ids=["no_denoise", "denoise"])

0 commit comments

Comments
 (0)