- 
                Notifications
    You must be signed in to change notification settings 
- Fork 96
Interpolating spline #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Interpolating spline #141
Changes from 20 commits
183c854
              af175d3
              1e0f7ed
              578f4d3
              bce7f58
              3d99327
              806726e
              aa8148a
              d5f2c42
              ecc9b59
              2c5190f
              1b8fc78
              c7b7be6
              51f9204
              9634c24
              a2294ed
              4e7aa35
              551e473
              37181dd
              8dbc059
              02b1f52
              bf28479
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -2,20 +2,244 @@ | |
| # MIT Licence, see details in top-level file: LICENCE | ||
|  | ||
| """ | ||
| Classes for parameterizing a trajectory in SE3 with B-splines. | ||
|  | ||
| Copies parts of the API from scipy's B-spline class. | ||
| Classes for parameterizing a trajectory in SE3 with splines. | ||
| """ | ||
|  | ||
| from typing import Any, Dict, List, Optional | ||
| from scipy.interpolate import BSpline | ||
| from spatialmath import SE3 | ||
| import numpy as np | ||
| from abc import ABC, abstractmethod | ||
| from functools import cached_property | ||
| from typing import List, Optional, Tuple | ||
|  | ||
| import matplotlib.pyplot as plt | ||
| from spatialmath.base.transforms3d import tranimate, trplot | ||
| import numpy as np | ||
| from scipy.interpolate import BSpline, CubicSpline | ||
| from scipy.spatial.transform import Rotation, RotationSpline | ||
|  | ||
| from spatialmath import SE3, SO3, Twist3 | ||
| from spatialmath.base.transforms3d import tranimate | ||
|  | ||
|  | ||
| class SplineSE3(ABC): | ||
| def __init__(self) -> None: | ||
| self.control_poses: SE3 | ||
|  | ||
| @abstractmethod | ||
| def __call__(self, t: float) -> SE3: | ||
| pass | ||
|  | ||
| def visualize( | ||
| self, | ||
| sample_times: List[float], | ||
| pose_marker_length: float = 0.2, | ||
| animate: bool = False, | ||
| repeat: bool = True, | ||
| ax: Optional[plt.Axes] = None, | ||
| input_trajectory: Optional[List[SE3]] = None, | ||
| ) -> None: | ||
| """Displays an animation of the trajectory with the control poses against an optional input trajectory. | ||
|  | ||
| Args: | ||
| times: which times to sample the spline at and plot | ||
| """ | ||
| if ax is None: | ||
| fig = plt.figure(figsize=(10, 10)) | ||
| ax = fig.add_subplot(projection="3d") | ||
|  | ||
| samples = [self(t) for t in sample_times] | ||
| if not animate: | ||
| pos = np.array([pose.t for pose in samples]) | ||
| ax.plot( | ||
| pos[:, 0], pos[:, 1], pos[:, 2], "c", linewidth=1.0 | ||
| ) # plot spline fit | ||
|  | ||
| pos = np.array([pose.t for pose in self.control_poses]) | ||
| ax.plot(pos[:, 0], pos[:, 1], pos[:, 2], "r*") # plot control_poses | ||
|  | ||
| if input_trajectory is not None: | ||
| pos = np.array([pose.t for pose in input_trajectory]) | ||
| ax.plot( | ||
| pos[:, 0], pos[:, 1], pos[:, 2], "go", fillstyle="none" | ||
| ) # plot compare to input poses | ||
|  | ||
| if animate: | ||
| tranimate( | ||
| samples, length=pose_marker_length, wait=True, repeat=repeat | ||
| ) # animate pose along trajectory | ||
| else: | ||
| plt.show() | ||
|  | ||
| class BSplineSE3: | ||
|  | ||
| class InterpSplineSE3(SplineSE3): | ||
| """Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline. | ||
|  | ||
| A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic) | ||
| under the hood. | ||
| """ | ||
|  | ||
| _e = 1e-12 | ||
|  | ||
| def __init__( | ||
| self, | ||
| timepoints: List[float], | ||
| control_poses: List[SE3], | ||
| *, | ||
| normalize_time: bool = False, | ||
| bc_type: str = "not-a-knot", # not-a-knot is scipy default; None is invalid | ||
| ) -> None: | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| """Construct a InterpSplineSE3 object | ||
|  | ||
| Extends the scipy CubicSpline object | ||
| https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline | ||
|  | ||
| Args : | ||
| timepoints : list of times corresponding to provided poses | ||
| control_poses : list of SE3 objects that govern the shape of the spline. | ||
| normalize_time : flag to map times into the range [0, 1] | ||
| bc_type : boundary condition provided to scipy CubicSpline backend. | ||
| string options: ["not-a-knot" (default), "clamped", "natural", "periodic"]. | ||
| For tuple options and details see the scipy docs link above. | ||
| """ | ||
| super().__init__() | ||
| self.control_poses = control_poses | ||
| self.timepoints = np.array(timepoints) | ||
|  | ||
| if self.timepoints[-1] < self._e: | ||
| raise ValueError( | ||
| "Difference between start and end timepoints is less than {self._e}" | ||
| ) | ||
|  | ||
| if len(self.control_poses) != len(self.timepoints): | ||
| raise ValueError("Length of control_poses and timepoints must be equal.") | ||
|  | ||
| if len(self.timepoints) < 2: | ||
| raise ValueError("Need at least 2 data points to make a trajectory.") | ||
|  | ||
| if normalize_time: | ||
| self.timepoints = self.timepoints - self.timepoints[0] | ||
| self.timepoints = self.timepoints / self.timepoints[-1] | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
|  | ||
| self.spline_xyz = CubicSpline( | ||
| self.timepoints, | ||
| np.array([pose.t for pose in self.control_poses]), | ||
| bc_type=bc_type, | ||
| ) | ||
| self.spline_so3 = RotationSpline( | ||
| self.timepoints, | ||
| Rotation.from_matrix(np.array([(pose.R) for pose in self.control_poses])), | ||
| ) | ||
|  | ||
| def __call__(self, t: float) -> SE3: | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| """Compute function value at t. | ||
| Return: | ||
| pose: SE3 | ||
| """ | ||
| return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix()) | ||
|  | ||
| def derivative(self, t: float) -> Twist3: | ||
| linear_vel = self.spline_xyz.derivative()(t) | ||
| angular_vel = self.spline_so3( | ||
| t, 1 | ||
| ) # 1 is angular rate, 2 is angular acceleration | ||
| return Twist3(linear_vel, angular_vel) | ||
|  | ||
|  | ||
| class SplineFit: | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| """A general class to fit various SE3 splines to data.""" | ||
|  | ||
| def __init__( | ||
| self, | ||
| time_data: List[float], | ||
| pose_data: List[SE3], | ||
| ) -> None: | ||
| self.time_data = time_data | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| self.pose_data = pose_data | ||
|  | ||
| self.xyz_data = np.array([pose.t for pose in pose_data]) | ||
| self.so3_data = Rotation.from_matrix(np.array([(pose.R) for pose in pose_data])) | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
|  | ||
| self.spline: Optional[SplineSE3] = None | ||
|  | ||
| def stochastic_downsample_interpolation( | ||
| self, | ||
| epsilon_xyz: float = 1e-3, | ||
| epsilon_angle: float = 1e-1, | ||
| normalize_time: bool = True, | ||
| bc_type: str = "not-a-knot", | ||
| ) -> Tuple[InterpSplineSE3, List[int]]: | ||
| """ | ||
| Uses a random dropout heuristic to downsample a trajectory with an interpolated spline. | ||
|  | ||
| This code does not ensure the global fit is within epsilon_xyz and epsilon_angle. | ||
|  | ||
| Return: | ||
| downsampled interpolating spline, | ||
| list of removed indices from input data | ||
| """ | ||
| spline = InterpSplineSE3( | ||
| self.time_data, | ||
| self.pose_data, | ||
| normalize_time=normalize_time, | ||
| bc_type=bc_type, | ||
| ) | ||
| chosen_indices: set[int] = set() | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| interpolation_indices = list(range(len(self.pose_data))) | ||
| interpolation_indices.remove(0) | ||
| interpolation_indices.remove(len(self.pose_data) - 1) | ||
|          | ||
|  | ||
| for _ in range(len(self.time_data) - 2): # you must have at least 2 indices | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved         
                  myeatman-bdai marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved          | ||
| choices = list(set(interpolation_indices).difference(chosen_indices)) | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
|  | ||
| index = np.random.choice(choices) | ||
|  | ||
| chosen_indices.add(index) | ||
| interpolation_indices.remove(index) | ||
|  | ||
| spline.spline_xyz = CubicSpline( | ||
| self.time_data[interpolation_indices], | ||
| self.xyz_data[interpolation_indices], | ||
| ) | ||
| spline.spline_so3 = RotationSpline( | ||
| self.time_data[interpolation_indices], | ||
| self.so3_data[interpolation_indices], | ||
| ) | ||
|  | ||
| time = self.time_data[index] | ||
| angular_error = SO3(self.pose_data[index]).angdist( | ||
| SO3(spline.spline_so3(time).as_matrix()) | ||
| ) | ||
| euclidean_error = np.linalg.norm( | ||
| self.pose_data[index].t - spline.spline_xyz(time) | ||
| ) | ||
| if (angular_error > epsilon_angle) or (euclidean_error > epsilon_xyz): | ||
| interpolation_indices.insert( | ||
| int(np.searchsorted(interpolation_indices, index, side="right")), | ||
| index, | ||
| ) | ||
|  | ||
| self.spline = spline | ||
| return spline, interpolation_indices | ||
|  | ||
| def max_angular_error(self) -> float: | ||
| return np.max(self.angular_errors) | ||
|  | ||
| @cached_property | ||
| def angular_errors(self) -> List[float]: | ||
|         
                  myeatman-bdai marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| return [ | ||
| pose.angdist(self.spline(t)) | ||
| for pose, t in zip(self.pose_data, self.time_data) | ||
| ] | ||
|  | ||
| def max_euclidean_error(self) -> float: | ||
| return np.max(self.euclidean_errors) | ||
|  | ||
| @cached_property | ||
| def euclidean_errors(self) -> List[float]: | ||
| return [ | ||
| np.linalg.norm(pose.t - self.spline(t).t) | ||
| for pose, t in zip(self.pose_data, self.time_data) | ||
| ] | ||
|  | ||
|  | ||
| class BSplineSE3(SplineSE3): | ||
| """A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline. | ||
|  | ||
| The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline | ||
|  | @@ -41,7 +265,7 @@ def __init__( | |
| at a given t input. If none, they are automatically, uniformly generated based on number of control poses and | ||
| degree of spline. | ||
| """ | ||
|  | ||
| super().__init__() | ||
| self.control_poses = control_poses | ||
|  | ||
| # a matrix where each row is a control pose as a twist | ||
|  | @@ -74,32 +298,3 @@ def __call__(self, t: float) -> SE3: | |
| """ | ||
| twist = np.hstack([spline(t) for spline in self.splines]) | ||
| return SE3.Exp(twist) | ||
|  | ||
| def visualize( | ||
| self, | ||
| num_samples: int, | ||
| length: float = 1.0, | ||
| repeat: bool = False, | ||
| ax: Optional[plt.Axes] = None, | ||
| kwargs_trplot: Dict[str, Any] = {"color": "green"}, | ||
| kwargs_tranimate: Dict[str, Any] = {"wait": True}, | ||
| kwargs_plot: Dict[str, Any] = {}, | ||
| ) -> None: | ||
| """Displays an animation of the trajectory with the control poses.""" | ||
| out_poses = [self(t) for t in np.linspace(0, 1, num_samples)] | ||
| x = [pose.x for pose in out_poses] | ||
| y = [pose.y for pose in out_poses] | ||
| z = [pose.z for pose in out_poses] | ||
|  | ||
| if ax is None: | ||
| fig = plt.figure(figsize=(10, 10)) | ||
| ax = fig.add_subplot(projection="3d") | ||
|  | ||
| trplot( | ||
| [np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot | ||
| ) # plot control points | ||
| ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory | ||
|  | ||
| tranimate( | ||
| out_poses, repeat=repeat, length=length, **kwargs_tranimate | ||
| ) # animate pose along trajectory | ||
Uh oh!
There was an error while loading. Please reload this page.