Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/source/nifti_images.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ You can get and set the qform affine using the equivalent methods to those for
the sform: ``get_qform()``, ``set_qform()``.

>>> n1_header.get_qform(coded=True)
(array([[ -2. , 0. , 0. , 117.86],
[ -0. , 1.97, -0.36, -35.72],
(array([[ -2. , 0. , -0. , 117.86],
[ 0. , 1.97, -0.36, -35.72],
[ 0. , 0.32, 2.17, -7.25],
[ 0. , 0. , 0. , 1. ]]), 1)

Expand Down
2 changes: 1 addition & 1 deletion nibabel/nifti1.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ class Nifti1Header(SpmAnalyzeHeader):
single_magic = b'n+1'

# Quaternion threshold near 0, based on float32 precision
quaternion_threshold = -np.finfo(np.float32).eps * 3
quaternion_threshold = np.finfo(np.float32).eps * 3

def __init__(self, binaryblock=None, endianness=None, check=True, extensions=()):
"""Initialize header from binary data block and extensions"""
Expand Down
18 changes: 9 additions & 9 deletions nibabel/quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def fillpositive(xyz, w2_thresh=None):
xyz : iterable
iterable containing 3 values, corresponding to quaternion x, y, z
w2_thresh : None or float, optional
threshold to determine if w squared is really negative.
threshold to determine if w squared is non-zero.
If None (default) then w2_thresh set equal to
``-np.finfo(xyz.dtype).eps``, if possible, otherwise
``-np.finfo(np.float64).eps``
3 * ``np.finfo(xyz.dtype).eps``, if possible, otherwise
3 * ``np.finfo(np.float64).eps``

Returns
-------
Expand Down Expand Up @@ -89,17 +89,17 @@ def fillpositive(xyz, w2_thresh=None):
# If necessary, guess precision of input
if w2_thresh is None:
try: # trap errors for non-array, integer array
w2_thresh = -np.finfo(xyz.dtype).eps * 3
w2_thresh = np.finfo(xyz.dtype).eps * 3
except (AttributeError, ValueError):
w2_thresh = -FLOAT_EPS * 3
w2_thresh = FLOAT_EPS * 3
# Use maximum precision
xyz = np.asarray(xyz, dtype=MAX_FLOAT)
# Calculate w
w2 = 1.0 - np.dot(xyz, xyz)
if w2 < 0:
if w2 < w2_thresh:
raise ValueError(f'w2 should be positive, but is {w2:e}')
w2 = 1.0 - xyz @ xyz
if np.abs(w2) < np.abs(w2_thresh):
w = 0
elif w2 < 0:
raise ValueError(f'w2 should be positive, but is {w2:e}')
else:
w = np.sqrt(w2)
return np.r_[w, xyz]
Expand Down
102 changes: 76 additions & 26 deletions nibabel/tests/test_quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,40 @@
from .. import eulerangles as nea
from .. import quaternions as nq


def norm(vec):
# Return unit vector with same orientation as input vector
return vec / np.sqrt(vec @ vec)


def gen_vec(dtype):
# Generate random 3-vector in [-1, 1]^3
rand = np.random.default_rng()
return rand.uniform(low=-1.0, high=1.0, size=(3,)).astype(dtype)


# Example rotations
eg_rots = []
params = (-pi, pi, pi / 2)
zs = np.arange(*params)
ys = np.arange(*params)
xs = np.arange(*params)
for z in zs:
for y in ys:
for x in xs:
eg_rots.append(nea.euler2mat(z, y, x))
eg_rots = [
nea.euler2mat(z, y, x)
for z in np.arange(-pi, pi, pi / 2)
for y in np.arange(-pi, pi, pi / 2)
for x in np.arange(-pi, pi, pi / 2)
]

# Example quaternions (from rotations)
eg_quats = []
for M in eg_rots:
eg_quats.append(nq.mat2quat(M))
eg_quats = [nq.mat2quat(M) for M in eg_rots]
# M, quaternion pairs
eg_pairs = list(zip(eg_rots, eg_quats))

# Set of arbitrary unit quaternions
unit_quats = set()
params = range(-2, 3)
for w in params:
for x in params:
for y in params:
for z in params:
q = (w, x, y, z)
Nq = np.sqrt(np.dot(q, q))
if not Nq == 0:
q = tuple([e / Nq for e in q])
unit_quats.add(q)
unit_quats = set(
tuple(norm(np.r_[w, x, y, z]))
for w in range(-2, 3)
for x in range(-2, 3)
for y in range(-2, 3)
for z in range(-2, 3)
if (w, x, y, z) != (0, 0, 0, 0)
)


def test_fillpos():
Expand All @@ -69,6 +74,51 @@ def test_fillpos():
assert wxyz[0] == 0.0


@pytest.mark.parametrize('dtype', ('f4', 'f8'))
def test_fillpositive_plus_minus_epsilon(dtype):
# Deterministic test for fillpositive threshold
# We are trying to fill (x, y, z) with a w such that |(w, x, y, z)| == 1
# If |(x, y, z)| is slightly off one, w should still be 0
nptype = np.dtype(dtype).type

# Obviously, |(x, y, z)| == 1
baseline = np.array([0, 0, 1], dtype=dtype)

# Obviously, |(x, y, z)| ~ 1
plus = baseline * nptype(1 + np.finfo(dtype).eps)
minus = baseline * nptype(1 - np.finfo(dtype).eps)

assert nq.fillpositive(plus)[0] == 0.0
assert nq.fillpositive(minus)[0] == 0.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthew-brett How's this for a deterministic test?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice - but - consider pusing to threshold to show where it fails?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. It actually fails at 2*eps, since the error compounds. Updated.


# |(x, y, z)| > 1, no real solutions
plus = baseline * nptype(1 + 2 * np.finfo(dtype).eps)
with pytest.raises(ValueError):
nq.fillpositive(plus)

# |(x, y, z)| < 1, two real solutions, we choose positive
minus = baseline * nptype(1 - 2 * np.finfo(dtype).eps)
assert nq.fillpositive(minus)[0] > 0.0


@pytest.mark.parametrize('dtype', ('f4', 'f8'))
def test_fillpositive_simulated_error(dtype):
# Nondeterministic test for fillpositive threshold
# Create random vectors, normalize to unit length, and count on floating point
# error to result in magnitudes larger/smaller than one
# This is to simulate cases where a unit quaternion with w == 0 would be encoded
# as xyz with small error, and we want to recover the w of 0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still like this as a bit of a stress test that will hit more realistic rounding errors, and added a comment to explain more. Can drop it if it seems excessive.


# Permit 1 epsilon per value (default, but make explicit here)
w2_thresh = 3 * np.finfo(dtype).eps

pos_error = neg_error = False
for _ in range(50):
xyz = norm(gen_vec(dtype))

assert nq.fillpositive(xyz, w2_thresh)[0] == 0.0


def test_conjugate():
# Takes sequence
cq = nq.conjugate((1, 0, 0, 0))
Expand Down Expand Up @@ -125,7 +175,7 @@ def test_norm():
def test_mult(M1, q1, M2, q2):
# Test that quaternion * same as matrix *
q21 = nq.mult(q2, q1)
assert_array_almost_equal, np.dot(M2, M1), nq.quat2mat(q21)
assert_array_almost_equal, M2 @ M1, nq.quat2mat(q21)


@pytest.mark.parametrize('M, q', eg_pairs)
Expand All @@ -146,7 +196,7 @@ def test_eye():
@pytest.mark.parametrize('M, q', eg_pairs)
def test_qrotate(vec, M, q):
vdash = nq.rotate_vector(vec, q)
vM = np.dot(M, vec)
vM = M @ vec
assert_array_almost_equal(vdash, vM)


Expand Down Expand Up @@ -179,6 +229,6 @@ def test_angle_axis():
nq.nearly_equivalent(q, q2)
aa_mat = nq.angle_axis2mat(theta, vec)
assert_array_almost_equal(aa_mat, M)
unit_vec = vec / np.sqrt(vec.dot(vec))
unit_vec = norm(vec)
aa_mat2 = nq.angle_axis2mat(theta, unit_vec, is_normalized=True)
assert_array_almost_equal(aa_mat2, M)