Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
repos:
# - repo: https://github.com/charliermarsh/ruff-pre-commit
# # Ruff version.
# rev: 'v0.1.0'
# hooks:
# - id: ruff
# args: ['--fix', '--config', 'pyproject.toml']
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.1.0'
hooks:
- id: ruff
args: ['--fix', '--config', 'pyproject.toml']

- repo: https://github.com/psf/black
rev: 23.10.0
Expand All @@ -14,6 +14,21 @@ repos:
args: ['--config', 'pyproject.toml']
verbose: true

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: debug-statements # Ensure we don't commit `import pdb; pdb.set_trace()`
exclude: |
(?x)^(
docker/ros/web/static/.*|
)$
- id: trailing-whitespace
exclude: |
(?x)^(
docker/ros/web/static/.*|
(.*/).*\.patch|
)$
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.6.1
# hooks:
Expand Down
4 changes: 3 additions & 1 deletion spatialmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from spatialmath.quaternion import Quaternion, UnitQuaternion
from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion
from spatialmath.spline import BSplineSE3
from spatialmath.spline import BSplineSE3, InterpSplineSE3, SplineFit

# from spatialmath.Plucker import *
# from spatialmath import base as smb
Expand Down Expand Up @@ -45,6 +45,8 @@
"Polygon2",
"Ellipse",
"BSplineSE3",
"InterpSplineSE3",
"SplineFit"
]

try:
Expand Down
2 changes: 1 addition & 1 deletion spatialmath/base/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def update(frame, animation):
if isinstance(frame, float):
# passed a single transform, interpolate it
T = smb.trinterp(start=self.start, end=self.end, s=frame)
elif isinstance(frame, NDArray):
elif isinstance(frame, np.ndarray):
# type is SO3Array or SE3Array when Animate.trajectory is not None
T = frame
else:
Expand Down
273 changes: 234 additions & 39 deletions spatialmath/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,244 @@
# MIT Licence, see details in top-level file: LICENCE

"""
Classes for parameterizing a trajectory in SE3 with B-splines.

Copies parts of the API from scipy's B-spline class.
Classes for parameterizing a trajectory in SE3 with splines.
"""

from typing import Any, Dict, List, Optional
from scipy.interpolate import BSpline
from spatialmath import SE3
import numpy as np
from abc import ABC, abstractmethod
from functools import cached_property
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
from spatialmath.base.transforms3d import tranimate, trplot
import numpy as np
from scipy.interpolate import BSpline, CubicSpline
from scipy.spatial.transform import Rotation, RotationSpline

from spatialmath import SE3, SO3, Twist3
from spatialmath.base.transforms3d import tranimate


class SplineSE3(ABC):
def __init__(self) -> None:
self.control_poses: SE3

@abstractmethod
def __call__(self, t: float) -> SE3:
pass

def visualize(
self,
sample_times: List[float],
pose_marker_length: float = 0.2,
animate: bool = False,
repeat: bool = True,
ax: Optional[plt.Axes] = None,
input_trajectory: Optional[List[SE3]] = None,
) -> None:
"""Displays an animation of the trajectory with the control poses against an optional input trajectory.

Args:
times: which times to sample the spline at and plot
"""
if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

samples = [self(t) for t in sample_times]
if not animate:
pos = np.array([pose.t for pose in samples])
ax.plot(
pos[:, 0], pos[:, 1], pos[:, 2], "c", linewidth=1.0
) # plot spline fit

pos = np.array([pose.t for pose in self.control_poses])
ax.plot(pos[:, 0], pos[:, 1], pos[:, 2], "r*") # plot control_poses

if input_trajectory is not None:
pos = np.array([pose.t for pose in input_trajectory])
ax.plot(
pos[:, 0], pos[:, 1], pos[:, 2], "go", fillstyle="none"
) # plot compare to input poses

if animate:
tranimate(
samples, length=pose_marker_length, wait=True, repeat=repeat
) # animate pose along trajectory
else:
plt.show()

class BSplineSE3:

class InterpSplineSE3(SplineSE3):
"""Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline.

A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
under the hood.
"""

_e = 1e-12

def __init__(
self,
timepoints: List[float],
control_poses: List[SE3],
*,
normalize_time: bool = False,
bc_type: str = "not-a-knot", # not-a-knot is scipy default; None is invalid
) -> None:
"""Construct a InterpSplineSE3 object

Extends the scipy CubicSpline object
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline

Args :
timepoints : list of times corresponding to provided poses
control_poses : list of SE3 objects that govern the shape of the spline.
normalize_time : flag to map times into the range [0, 1]
bc_type : boundary condition provided to scipy CubicSpline backend.
string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
For tuple options and details see the scipy docs link above.
"""
super().__init__()
self.control_poses = control_poses
self.timepoints = np.array(timepoints)

if self.timepoints[-1] < self._e:
raise ValueError(
"Difference between start and end timepoints is less than {self._e}"
)

if len(self.control_poses) != len(self.timepoints):
raise ValueError("Length of control_poses and timepoints must be equal.")

if len(self.timepoints) < 2:
raise ValueError("Need at least 2 data points to make a trajectory.")

if normalize_time:
self.timepoints = self.timepoints - self.timepoints[0]
self.timepoints = self.timepoints / self.timepoints[-1]

self.spline_xyz = CubicSpline(
self.timepoints,
np.array([pose.t for pose in self.control_poses]),
bc_type=bc_type,
)
self.spline_so3 = RotationSpline(
self.timepoints,
Rotation.from_matrix(np.array([(pose.R) for pose in self.control_poses])),
)

def __call__(self, t: float) -> SE3:
"""Compute function value at t.
Return:
pose: SE3
"""
return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix())

def derivative(self, t: float) -> Twist3:
linear_vel = self.spline_xyz.derivative()(t)
angular_vel = self.spline_so3(
t, 1
) # 1 is angular rate, 2 is angular acceleration
return Twist3(linear_vel, angular_vel)


class SplineFit:
"""A general class to fit various SE3 splines to data."""

def __init__(
self,
time_data: List[float],
pose_data: List[SE3],
) -> None:
self.time_data = time_data
self.pose_data = pose_data

self.xyz_data = np.array([pose.t for pose in pose_data])
self.so3_data = Rotation.from_matrix(np.array([(pose.R) for pose in pose_data]))

self.spline: Optional[SplineSE3] = None

def stochastic_downsample_interpolation(
self,
epsilon_xyz: float = 1e-3,
epsilon_angle: float = 1e-1,
normalize_time: bool = True,
bc_type: str = "not-a-knot",
) -> Tuple[InterpSplineSE3, List[int]]:
"""
Uses a random dropout heuristic to downsample a trajectory with an interpolated spline.

This code does not ensure the global fit is within epsilon_xyz and epsilon_angle.

Return:
downsampled interpolating spline,
list of removed indices from input data
"""
spline = InterpSplineSE3(
self.time_data,
self.pose_data,
normalize_time=normalize_time,
bc_type=bc_type,
)
chosen_indices: set[int] = set()
interpolation_indices = list(range(len(self.pose_data)))
interpolation_indices.remove(0)
interpolation_indices.remove(len(self.pose_data) - 1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use range(1, len(self.pose_data) - 1)? Or `range(len(self.pose_data))[1:-1]? The current way seems a bit unconventional but I guess it doesn't matter.


for _ in range(len(self.time_data) - 2): # you must have at least 2 indices

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May as well use len(interpolation_indices) and drop the - 2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the length of interpolation indices is changing as the loop executes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your suggestion might work anyway, but that's why i didn't code it like that originally

choices = list(set(interpolation_indices).difference(chosen_indices))

index = np.random.choice(choices)

chosen_indices.add(index)
interpolation_indices.remove(index)

spline.spline_xyz = CubicSpline(
self.time_data[interpolation_indices],
self.xyz_data[interpolation_indices],
)
spline.spline_so3 = RotationSpline(
self.time_data[interpolation_indices],
self.so3_data[interpolation_indices],
)

time = self.time_data[index]
angular_error = SO3(self.pose_data[index]).angdist(
SO3(spline.spline_so3(time).as_matrix())
)
euclidean_error = np.linalg.norm(
self.pose_data[index].t - spline.spline_xyz(time)
)
if (angular_error > epsilon_angle) or (euclidean_error > epsilon_xyz):
interpolation_indices.insert(
int(np.searchsorted(interpolation_indices, index, side="right")),
index,
)

self.spline = spline
return spline, interpolation_indices

def max_angular_error(self) -> float:
return np.max(self.angular_errors)

@cached_property
def angular_errors(self) -> List[float]:
return [
pose.angdist(self.spline(t))
for pose, t in zip(self.pose_data, self.time_data)
]

def max_euclidean_error(self) -> float:
return np.max(self.euclidean_errors)

@cached_property
def euclidean_errors(self) -> List[float]:
return [
np.linalg.norm(pose.t - self.spline(t).t)
for pose, t in zip(self.pose_data, self.time_data)
]


class BSplineSE3(SplineSE3):
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.

The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
Expand All @@ -41,7 +265,7 @@ def __init__(
at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
degree of spline.
"""

super().__init__()
self.control_poses = control_poses

# a matrix where each row is a control pose as a twist
Expand Down Expand Up @@ -74,32 +298,3 @@ def __call__(self, t: float) -> SE3:
"""
twist = np.hstack([spline(t) for spline in self.splines])
return SE3.Exp(twist)

def visualize(
self,
num_samples: int,
length: float = 1.0,
repeat: bool = False,
ax: Optional[plt.Axes] = None,
kwargs_trplot: Dict[str, Any] = {"color": "green"},
kwargs_tranimate: Dict[str, Any] = {"wait": True},
kwargs_plot: Dict[str, Any] = {},
) -> None:
"""Displays an animation of the trajectory with the control poses."""
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)]
x = [pose.x for pose in out_poses]
y = [pose.y for pose in out_poses]
z = [pose.z for pose in out_poses]

if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

trplot(
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot
) # plot control points
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory

tranimate(
out_poses, repeat=repeat, length=length, **kwargs_tranimate
) # animate pose along trajectory
Loading
Loading