Skip to content
110 changes: 98 additions & 12 deletions spatialmath/base/quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
import math
import numpy as np
import spatialmath.base as smb
from spatialmath.base.argcheck import getunit
from spatialmath.base.types import *
import scipy.interpolate as interpolate
from typing import Optional
from functools import lru_cache

_eps = np.finfo(np.float64).eps


def qeye() -> QuaternionArray:
"""
Create an identity quaternion
Expand Down Expand Up @@ -843,29 +846,112 @@ def qslerp(
return q0


def qrand() -> UnitQuaternionArray:
def _compute_cdf_sin_squared(theta: float):
"""
Random unit-quaternion
Computes the CDF for the distribution of angular magnitude for uniformly sampled rotations.

:arg theta: angular magnitude
:rtype: float
:return: cdf of a given angular magnitude
:rtype: float

Helper function for uniform sampling of rotations with constrained angular magnitude.
This function returns the integral of the pdf of angular magnitudes (2/pi * sin^2(theta/2)).
"""
return (theta - np.sin(theta)) / np.pi

@lru_cache(maxsize=1)
def _generate_inv_cdf_sin_squared_interp(num_interpolation_points: int = 256) -> interpolate.interp1d:
"""
Computes an interpolation function for the inverse CDF of the distribution of angular magnitude.

:arg num_interpolation_points: number of points to use in the interpolation function
:rtype: int
:return: interpolation function for the inverse cdf of a given angular magnitude
:rtype: interpolate.interp1d

Helper function for uniform sampling of rotations with constrained angular magnitude.
This function returns interpolation function for the inverse of the integral of the
pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
"""
cdf_sin_squared_interp_angles = np.linspace(0, np.pi, num_interpolation_points)
cdf_sin_squared_interp_values = _compute_cdf_sin_squared(cdf_sin_squared_interp_angles)
return interpolate.interp1d(cdf_sin_squared_interp_values, cdf_sin_squared_interp_angles)

def _compute_inv_cdf_sin_squared(x: ArrayLike, num_interpolation_points: int = 256) -> ArrayLike:
"""
Computes the inverse CDF of the distribution of angular magnitude.

:arg x: value for cdf of angular magnitudes
:rtype: ArrayLike
:arg num_interpolation_points: number of points to use in the interpolation function
:rtype: int
:return: angular magnitude associate with cdf value
:rtype: ArrayLike

Helper function for uniform sampling of rotations with constrained angular magnitude.
This function returns the angle associated with the cdf value derived form integral of
the pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
"""
inv_cdf_sin_squared_interp = _generate_inv_cdf_sin_squared_interp(num_interpolation_points)
return inv_cdf_sin_squared_interp(x)

def qrand(theta_range:Optional[ArrayLike2] = None, unit: str = "rad", num_interpolation_points: int = 256) -> UnitQuaternionArray:
"""
Random unit-quaternion

:arg theta_range: angular magnitude range [min,max], defaults to None.
:type xrange: 2-element sequence, optional
:arg unit: angular units: 'rad' [default], or 'deg'
:type unit: str
:arg num_interpolation_points: number of points to use in the interpolation function
:rtype: int
:arg num_interpolation_points: number of points to use in the interpolation function
:rtype: int
:return: random unit-quaternion
:rtype: ndarray(4)

Computes a uniformly distributed random unit-quaternion which can be
considered equivalent to a random SO(3) rotation.
Computes a uniformly distributed random unit-quaternion, with in a maximum
angular magnitude, which can be considered equivalent to a random SO(3) rotation.

.. runblock:: pycon

>>> from spatialmath.base import qrand, qprint
>>> qprint(qrand())
"""
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
return np.r_[
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
]
if theta_range is not None:
theta_range = getunit(theta_range, unit)

if(theta_range[0] < 0 or theta_range[1] > np.pi or theta_range[0] > theta_range[1]):
ValueError('Invalid angular range. Must be within the range[0, pi].'
+ f' Recieved {theta_range}.')

# Sample axis and angle independently, respecting the CDF of the
# angular magnitude under uniform sampling.

# Sample angle using inverse transform sampling based on CDF
# of the angular distribution (2/pi * sin^2(theta/2))
theta = _compute_inv_cdf_sin_squared(
np.random.uniform(
low=_compute_cdf_sin_squared(theta_range[0]),
high=_compute_cdf_sin_squared(theta_range[1]),
),
num_interpolation_points=num_interpolation_points,
)
# Sample axis uniformly using 3D normal distributed
v = np.random.randn(3)
v /= np.linalg.norm(v)

return np.r_[math.cos(theta / 2), (math.sin(theta / 2) * v)]
else:
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
return np.r_[
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
]


def qmatrix(q: ArrayLike4) -> R4x4:
"""
Expand Down
18 changes: 14 additions & 4 deletions spatialmath/pose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from spatialmath.twist import Twist3

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from spatialmath.quaternion import UnitQuaternion

Expand Down Expand Up @@ -455,12 +455,16 @@ def Rz(cls, theta, unit: str = "rad") -> Self:
return cls([smb.rotz(x, unit=unit) for x in smb.getvector(theta)], check=False)

@classmethod
def Rand(cls, N: int = 1) -> Self:
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> Self:
"""
Construct a new SO(3) from random rotation

:param N: number of random rotations
:type N: int
:param theta_range: angular magnitude range [min,max], defaults to None.
:type xrange: 2-element sequence, optional
:param unit: angular units: 'rad' [default], or 'deg'
:type unit: str
:return: SO(3) rotation matrix
:rtype: SO3 instance

Expand All @@ -477,7 +481,7 @@ def Rand(cls, N: int = 1) -> Self:

:seealso: :func:`spatialmath.quaternion.UnitQuaternion.Rand`
"""
return cls([smb.q2r(smb.qrand()) for _ in range(0, N)], check=False)
return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False)

@overload
@classmethod
Expand Down Expand Up @@ -1517,6 +1521,8 @@ def Rand(
xrange: Optional[ArrayLike2] = (-1, 1),
yrange: Optional[ArrayLike2] = (-1, 1),
zrange: Optional[ArrayLike2] = (-1, 1),
theta_range:Optional[ArrayLike2] = None,
unit: str = "rad",
) -> SE3: # pylint: disable=arguments-differ
"""
Create a random SE(3)
Expand All @@ -1527,6 +1533,10 @@ def Rand(
:type yrange: 2-element sequence, optional
:param zrange: z-axis range [min,max], defaults to [-1, 1]
:type zrange: 2-element sequence, optional
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
:type xrange: 2-element sequence, optional
:param unit: angular units: 'rad' [default], or 'deg'
:type unit: str
:param N: number of random transforms
:type N: int
:return: SE(3) matrix
Expand Down Expand Up @@ -1557,7 +1567,7 @@ def Rand(
Z = np.random.uniform(
low=zrange[0], high=zrange[1], size=N
) # random values in the range
R = SO3.Rand(N=N)
R = SO3.Rand(N=N, theta_range=theta_range, unit=unit)
return cls(
[smb.transl(x, y, z) @ smb.r2t(r.A) for (x, y, z, r) in zip(X, Y, Z, R)],
check=False,
Expand Down
8 changes: 6 additions & 2 deletions spatialmath/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,12 +1225,16 @@ def Rz(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
)

@classmethod
def Rand(cls, N: int = 1) -> UnitQuaternion:
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> UnitQuaternion:
"""
Construct a new random unit quaternion

:param N: number of random rotations
:type N: int
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
:type xrange: 2-element sequence, optional
:param unit: angular units: 'rad' [default], or 'deg'
:type unit: str
:return: random unit-quaternion
:rtype: UnitQuaternion instance

Expand All @@ -1248,7 +1252,7 @@ def Rand(cls, N: int = 1) -> UnitQuaternion:

:seealso: :meth:`UnitQuaternion.Rand`
"""
return cls([smb.qrand() for i in range(0, N)], check=False)
return cls([smb.qrand(theta_range=theta_range, unit=unit) for i in range(0, N)], check=False)

@classmethod
def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternion:
Expand Down
18 changes: 17 additions & 1 deletion tests/test_pose3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,19 @@ def test_constructor(self):
array_compare(R, np.eye(3))
self.assertIsInstance(R, SO3)

np.random.seed(32)
# random
R = SO3.Rand()
nt.assert_equal(len(R), 1)
self.assertIsInstance(R, SO3)

# random constrained
R = SO3.Rand(theta_range=(0.1, 0.7))
self.assertIsInstance(R, SO3)
self.assertEqual(R.A.shape, (3, 3))
self.assertLessEqual(R.angvec()[0], 0.7)
self.assertGreaterEqual(R.angvec()[0], 0.1)

# copy constructor
R = SO3.Rx(pi / 2)
R2 = SO3(R)
Expand Down Expand Up @@ -816,12 +824,13 @@ def test_constructor(self):
array_compare(R, np.eye(4))
self.assertIsInstance(R, SE3)

np.random.seed(65)
# random
R = SE3.Rand()
nt.assert_equal(len(R), 1)
self.assertIsInstance(R, SE3)

# random
# random
T = SE3.Rand()
R = T.R
t = T.t
Expand All @@ -847,6 +856,13 @@ def test_constructor(self):
nt.assert_equal(TT.y, ones * t[1])
nt.assert_equal(TT.z, ones * t[2])

# random constrained
T = SE3.Rand(theta_range=(0.1, 0.7))
self.assertIsInstance(T, SE3)
self.assertEqual(T.A.shape, (4, 4))
self.assertLessEqual(T.angvec()[0], 0.7)
self.assertGreaterEqual(T.angvec()[0], 0.1)

# copy constructor
R = SE3.Rx(pi / 2)
R2 = SE3(R)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def test_constructor_variants(self):
nt.assert_array_almost_equal(
UnitQuaternion.Rz(-90, "deg").vec, np.r_[1, 0, 0, -1] / math.sqrt(2)
)

np.random.seed(73)
q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
self.assertIsInstance(q, UnitQuaternion)
self.assertLessEqual(q.angvec()[0], 0.7)
self.assertGreaterEqual(q.angvec()[0], 0.1)


q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
self.assertIsInstance(q, UnitQuaternion)
self.assertLessEqual(q.angvec()[0], 0.7)
self.assertGreaterEqual(q.angvec()[0], 0.1)


def test_constructor(self):
qcompare(UnitQuaternion(), [1, 0, 0, 0])
Expand Down
Loading