Skip to content

Commit a269308

Browse files
committed
Merge branch 'io_orientation' into main-master
* io_orientation: TST: add tests for io_orientations thresholding RF: simplify allclose test DOC: update docstring to reflect new tolerance BF: fixing RS.shape + use allclose (was allzeros) ENH: making the svd tol matching MATLAB's tol
2 parents 5b037c3 + ebd999c commit a269308

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

nibabel/orientations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def io_orientation(affine, tol=None):
3737
threshold below which SVD values of the affine are considered zero. If
3838
`tol` is None, and ``S`` is an array with singular values for `affine`,
3939
and ``eps`` is the epsilon value for datatype of ``S``, then `tol` set to
40-
``S.max() * eps``.
40+
``S.max() * max((q, p)) * eps``
4141
4242
Returns
4343
-------
@@ -62,7 +62,7 @@ def io_orientation(affine, tol=None):
6262
P, S, Qs = npl.svd(RS)
6363
# Threshold the singular values to determine the rank.
6464
if tol is None:
65-
tol = S.max() * np.finfo(S.dtype).eps
65+
tol = S.max() * max(RS.shape) * np.finfo(S.dtype).eps
6666
keep = (S > tol)
6767
R = np.dot(P[:, keep], Qs[keep])
6868
# the matrix R is such that np.dot(R,R.T) is projection onto the
@@ -75,7 +75,7 @@ def io_orientation(affine, tol=None):
7575
ornt = np.ones((p, 2), dtype=np.int8) * np.nan
7676
for in_ax in range(p):
7777
col = R[:, in_ax]
78-
if not np.alltrue(np.equal(col, 0)):
78+
if not np.allclose(col, 0):
7979
out_ax = np.argmax(np.abs(col))
8080
ornt[in_ax, 0] = out_ax
8181
assert col[out_ax] != 0

nibabel/tests/test_orientations.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,33 @@ def test_io_orientation():
176176
[np.nan, np.nan],
177177
[np.nan, np.nan],
178178
[np.nan, np.nan]])
179+
# Test behavior of thresholding
180+
def_aff = np.array([[1., 1, 0, 0],
181+
[0, 0, 0, 0],
182+
[0, 0, 1, 0],
183+
[0, 0, 0, 1]])
184+
fail_tol = np.array([[0, 1],
185+
[np.nan, np.nan],
186+
[2, 1]])
187+
pass_tol = np.array([[0, 1],
188+
[1, 1],
189+
[2, 1]])
190+
eps = np.finfo(float).eps
191+
# Test that a Y axis appears as we increase the difference between the first
192+
# two columns
193+
for y_val, has_y in ((0, False),
194+
(eps, False),
195+
(eps * 5, False),
196+
(eps * 10, True),
197+
):
198+
def_aff[1, 1] = y_val
199+
res = pass_tol if has_y else fail_tol
200+
assert_array_equal(io_orientation(def_aff), res)
201+
# Test tol input argument
202+
def_aff[1, 1] = eps
203+
assert_array_equal(io_orientation(def_aff, tol=0), pass_tol)
204+
def_aff[1, 1] = eps * 10
205+
assert_array_equal(io_orientation(def_aff, tol=1e-5), fail_tol)
179206

180207

181208
def test_ornt2axcodes():

0 commit comments

Comments
 (0)