diff --git a/src/pytorch_kinematics/chain.py b/src/pytorch_kinematics/chain.py index d837567..e0900c1 100644 --- a/src/pytorch_kinematics/chain.py +++ b/src/pytorch_kinematics/chain.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Optional, Sequence +from typing import Collection, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -7,6 +7,7 @@ import pytorch_kinematics.transforms as tf from pytorch_kinematics import jacobian from pytorch_kinematics.frame import Frame, Link, Joint +from pytorch_kinematics.transforms.parameterized_transform import ParameterizedTransform from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_44, axis_and_d_to_pris_matrix @@ -189,7 +190,7 @@ def _find_link_recursive(name, frame) -> Optional[Link]: @staticmethod def _get_joints(frame, exclude_fixed=True): joints = [] - if exclude_fixed and frame.joint.joint_type != "fixed": + if not exclude_fixed or frame.joint.joint_type != "fixed": joints.append(frame.joint) for child in frame.children: joints.extend(Chain._get_joints(child)) @@ -293,12 +294,32 @@ def get_link_names(self): def get_frame_indices(self, *frame_names): return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device) - def forward_kinematics(self, th, frame_indices: Optional = None): + def _get_jnt_transform(self, th) -> Tuple[torch.Tensor, torch.Tensor]: """ - Compute forward kinematics for the given joint values. + compute all joint transforms at once first in order to handle multiple joint types without branching, + we create all possible transforms for all joint types and then select the appropriate one for each joint. + Args: + th: The joint configuration to use + + Returns: A tuple of revolute, prismatic joint transforms + """ + axes_expanded = self.axes.unsqueeze(0).repeat(th.shape[0], 1, 1).to(th) + return axis_and_angle_to_matrix_44(axes_expanded, th), axis_and_d_to_pris_matrix(axes_expanded, th) + + def forward_kinematics(self, + th: Optional[torch.Tensor] = None, + joint_offsets: Optional[Union[torch.Tensor, tf.Transform3d]] = None, + link_offsets: Optional[Union[torch.Tensor, tf.Transform3d]] = None, + frame_indices: Optional = None): + """ + Compute forward kinematics for the given any combination of joint values, joint offsets, and link offsets. Args: th: A dict, list, numpy array, or torch tensor of joints values. Possibly batched. + joint_offsets: A Transform3d object or a tensor of shape (N, 4, 4) where N is the number of joints. + If provided, overrides the joint offsets in the chain. + link_offsets: A Transform3d object or a tensor of shape (N, 4, 4) where N is the number of joints. + If provided, overrides the link offsets in the chain. frame_indices: A list of frame indices to compute transforms for. If None, all frames are computed. Use `get_frame_indices` to convert from frame names to frame indices. @@ -306,36 +327,58 @@ def forward_kinematics(self, th, frame_indices: Optional = None): A dict of Transform3d objects for each frame. """ + def get_ith_transform(offset, i): + if isinstance(offset, torch.Tensor): + return offset[:, i, ...] + return offset[i] + if frame_indices is None: frame_indices = self.get_all_frame_indices() - th = self.ensure_tensor(th) - th = torch.atleast_2d(th) + if isinstance(joint_offsets, tf.Transform3d): + joint_offsets = joint_offsets.get_matrix().view(-1, len(self.joint_offsets), 4, 4) + if isinstance(link_offsets, tf.Transform3d): + link_offsets = link_offsets.get_matrix().view(-1, len(self.link_offsets), 4, 4) + + if th is joint_offsets is link_offsets is None: + raise ValueError("Must provide at least one of th, joint_offsets, or link_offsets.") + if th is not None: + th = self.ensure_tensor(th) + th = torch.atleast_2d(th) + b = th.shape[0] + to_this = th + elif joint_offsets is not None: + b = joint_offsets.shape[0] + to_this = joint_offsets + else: + b = link_offsets.shape[0] + to_this = link_offsets - b = th.shape[0] - axes_expanded = self.axes.unsqueeze(0).repeat(b, 1, 1) + # initialize default values + if th is None: + th = torch.zeros([b, self.n_joints]).to(to_this) + if joint_offsets is None: + joint_offsets = self.joint_offsets + if link_offsets is None: + link_offsets = self.link_offsets - # compute all joint transforms at once first - # in order to handle multiple joint types without branching, we create all possible transforms - # for all joint types and then select the appropriate one for each joint. - rev_jnt_transform = axis_and_angle_to_matrix_44(axes_expanded, th) - pris_jnt_transform = axis_and_d_to_pris_matrix(axes_expanded, th) + rev_jnt_transform, pris_jnt_transform = self._get_jnt_transform(th) frame_transforms = {} - b = th.shape[0] for frame_idx in frame_indices: - frame_transform = torch.eye(4).to(th).unsqueeze(0).repeat(b, 1, 1) + frame_transform = torch.eye(4).to(to_this).unsqueeze(0).repeat(b, 1, 1) # iterate down the list and compose the transform for chain_idx in self.parents_indices[frame_idx.item()]: if chain_idx.item() in frame_transforms: frame_transform = frame_transforms[chain_idx.item()] else: - link_offset_i = self.link_offsets[chain_idx] + link_offset_i = get_ith_transform(link_offsets, chain_idx) if link_offset_i is not None: frame_transform = frame_transform @ link_offset_i - joint_offset_i = self.joint_offsets[chain_idx] + joint_offset_i = get_ith_transform(joint_offsets, chain_idx) + if joint_offset_i is not None: frame_transform = frame_transform @ joint_offset_i @@ -458,35 +501,93 @@ def _generate_serial_chain_recurse(root_frame, end_frame_name): return [child] + frames return None + @classmethod + def from_joint_transforms(cls, + transforms: tf.Transform3d, + link_offsets: Optional[tf.Transform3d] = None, + joint_names: Optional[Collection[str]] = None, + link_names: Optional[Collection[str]] = None, + joint_types: Optional[Collection[str]] = None, + **kwargs + ): + """ + Create a serial chain with zero link offsets and joint offsets according to the input. + + Assumes that frame 0 is the root frame that contains an empty link and a fixed "world" joint. Accordingly, frame + i is aligned with joint i, which moves. All joint axes are aligned with the z-axis of the joint frame. + Args: + transforms: A transform that represents a matrix of shape (N, 4, 4) where N is the number of joints. + link_offsets: A transform that represents a matrix of shape (N, 4, 4) where N is the number of joints. + Optional. If None, all link offsets are assumed to be zero. + joint_names: The names of the joints. If None, the joints are named "joint_0", "joint_1", etc. + link_names: The names of the links. If None, the links are named "link_0", "link_1", etc. + joint_types: The types of the joints. If None, the joints are assumed to be revolute. + """ + device = kwargs.get('device', transforms.device) + dtype = kwargs.get('dtype', transforms.dtype) + + transforms = transforms.to(device=device, dtype=dtype) + joint_offsets = transforms.get_matrix() + assert len(joint_offsets.shape) == 3, "Expected a 3D matrix of shape (N, 4, 4)." + n = joint_offsets.shape[0] + if link_offsets is None: + link_offsets = [None] * n + else: + link_offsets = link_offsets.to(device=device, dtype=dtype) + if joint_names is None: + joint_names = [f"joint_{i+1}" for i in range(n)] + if link_names is None: + link_names = [f"link_{i+1}" for i in range(n)] + if joint_types is None: + joint_types = ['revolute' for _ in range(n)] + root_frame = Frame(name="world") + root_frame.link = Link(name="world") + root_frame.joint = Joint(name="world", joint_type="fixed") + children = [] + for (i, link, joint, joint_type) in reversed(list(zip(range(n), link_names, joint_names, joint_types))): + frame = Frame(name=f"{link}") + frame.link = Link(name=link, offset=link_offsets[i]) + frame.joint = Joint(name=joint, offset=transforms[i], joint_type=joint_type) + frame.children = children + children = [frame] + root_frame.children = children + return cls(Chain(root_frame, **kwargs), link_names[-1], root_frame_name="world") + def jacobian(self, th, locations=None): if locations is not None: locations = tf.Transform3d(pos=locations) return jacobian.calc_jacobian(self, th, tool=locations) - def forward_kinematics(self, th, end_only: bool = True): + def forward_kinematics(self, + th: Optional[torch.Tensor] = None, + joint_offsets: Optional[torch.Tensor] = None, + link_offsets: Optional[torch.Tensor] = None, + end_only: bool = True): """ Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """ - frame_indices, th = self.convert_serial_inputs_to_chain_inputs(th, end_only) + if end_only: + frame_indices = self.get_frame_indices(self._serial_frames[-1].name) + else: + # pass through default behavior for frame indices being None, which is currently + # to return all frames. + frame_indices = None + if th is not None: + th = self.convert_serial_inputs_to_chain_inputs(th) - mat = super().forward_kinematics(th, frame_indices) + mat = super().forward_kinematics(th, joint_offsets=joint_offsets, link_offsets=link_offsets, + frame_indices=frame_indices) if end_only: return mat[self._serial_frames[-1].name] else: return mat - def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool): + def convert_serial_inputs_to_chain_inputs(self, th: torch.Tensor): # th = self.ensure_tensor(th) th_b = get_batch_size(th) th_n_joints = get_n_joints(th) if isinstance(th, list): th = torch.tensor(th, device=self.device, dtype=self.dtype) - if end_only: - frame_indices = self.get_frame_indices(self._serial_frames[-1].name) - else: - # pass through default behavior for frame indices being None, which is currently - # to return all frames. - frame_indices = None if th_n_joints < self.n_joints: # if th is only a partial list of joints, assume it's a list of joints for only the serial chain. partial_th = th @@ -504,4 +605,4 @@ def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool): jnt_idx = self.joint_indices[k] if frame.joint.joint_type != 'fixed': th[..., jnt_idx] = partial_th_i - return frame_indices, th + return th diff --git a/src/pytorch_kinematics/transforms/parameter_conversions.py b/src/pytorch_kinematics/transforms/parameter_conversions.py new file mode 100644 index 0000000..fb0f163 --- /dev/null +++ b/src/pytorch_kinematics/transforms/parameter_conversions.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Author: Jonathan Külz +# Date: 23.11.23 +import torch + + +def mdh_to_homogeneous(mdh_parameters: torch.Tensor) -> torch.Tensor: + """ + Converts a set of MDH parameters to a homogeneous transformation matrix. + + Follows Craig, Introduction to Robotics, 2005, p. 75. + :param mdh_parameters: The MDH parameters, ordered as alpha, a, d, theta. + :return: The homogeneous transformation matrix. + """ + alpha = mdh_parameters[..., 0] + a = mdh_parameters[..., 1] + d = mdh_parameters[..., 2] + theta = mdh_parameters[..., 3] + + ct = torch.cos(theta) + st = torch.sin(theta) + ca = torch.cos(alpha) + sa = torch.sin(alpha) + zeros = torch.zeros_like(theta) + return torch.stack([ + torch.stack([ct, -st, zeros, a], dim=-1), + torch.stack([st * ca, ct * ca, -sa, -d * sa], dim=-1), + torch.stack([st * sa, ct * sa, ca, d * ca], dim=-1), + torch.stack([zeros, zeros, zeros, torch.ones_like(theta)], + dim=-1) + ], dim=-2) + + +def homogeneous_to_mdh(T: torch.Tensor) -> torch.Tensor: + """ + Converts a homogeneous transformation matrix to a set of MDH parameters. + + Attention, this method is expensive due to an internal sanity check. + Follows Craig, Introduction to Robotics, 2005, p. 75. + :param T: The homogeneous transformation matrix. + :return: The MDH parameters. + """ + a = T[..., 0, 3] + theta = torch.atan2(-T[..., 0, 1], T[..., 0, 0]) + alpha = torch.atan2(-T[..., 1, 2], T[..., 2, 2]) + d = torch.empty_like(a) + use_cos = torch.isclose(torch.sin(alpha), torch.zeros_like(alpha)) + d[~use_cos] = -T[~use_cos][:, 1, 3] / torch.sin(alpha[~use_cos]) + d[use_cos] = T[use_cos][:, 2, 3] / torch.cos(alpha[use_cos]) + + parameters = torch.stack([alpha, a, d, theta], dim=-1) + if not torch.allclose(mdh_to_homogeneous(parameters), T, atol=1e-3): + raise ValueError('The given transformation is not MDH.') + + return parameters diff --git a/src/pytorch_kinematics/transforms/parameterized_transform.py b/src/pytorch_kinematics/transforms/parameterized_transform.py new file mode 100644 index 0000000..5caae67 --- /dev/null +++ b/src/pytorch_kinematics/transforms/parameterized_transform.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +from functools import lru_cache +from typing import Collection, Iterable, Optional, Tuple, Union + +import numpy as np +import torch + +from .transform3d import Transform3d +from .parameter_conversions import mdh_to_homogeneous, homogeneous_to_mdh + + +class ParameterConvention(Enum): + """A parameter convention for a kinematic chain.""" + + MDH = 1 # Modified Denavit-Hartenberg convention after Craig + + +class ParameterizedTransform(Transform3d, ABC): + """ + A 3d transform which is made from a set of (differentiable) parameters. + + The ParameterizedTransform class supports two levels of batching: The first level is the batch dimension of the + parameters for a single robot, the second level is the batch dimension of multiple robots. + """ + + convention: ParameterConvention # Implement this in subclasses + parameter_names: Tuple[str] # Implement this in subclasses + + def __init__( + self, + parameters: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + device: str = 'cpu', + requires_grad: bool = True, + matrix: Optional[torch.Tensor] = None, + default_batch_size: Union[Tuple[int], Tuple[int, int]] = (1, 1) + ): + """Initialize a ParameterizedTransform.""" + super().__init__(dtype=dtype, device=device, matrix=matrix) + assert len(default_batch_size) in (1, 2), "default_batch_size must be a tuple of length 1 or 2" + + if parameters is None: + parameters = torch.zeros(*default_batch_size, self.get_num_parameters()) + while parameters.ndim < 1 + len(default_batch_size): + parameters = parameters.unsqueeze(0) + + self.default_batch_size: Union[Tuple[int], Tuple[int, int]] = default_batch_size + self.requires_grad: bool = requires_grad + self.parameters: torch.Tensor = parameters.to(self.device, self.dtype) + self._matrix = None + + @abstractmethod + def get_matrix(self) -> torch.Tensor: + """Returns the matrix representation of the transform.""" + + @abstractmethod + def update_joint_parameters(self, th: torch.Tensor, joint_types: np.array): + """Updates the parameters of the parameters according to joint configuration th.""" + + def clone(self) -> ParameterizedTransform: + """ + Deep copy of ParameterizedTransform object. All internal tensors are cloned individually. + + Returns: + new ParameterizedTransform object. + """ + other = self.__class__(dtype=self.dtype, device=self.device, requires_grad=self.requires_grad) + other._matrix = self.get_matrix() + other.parameters = self._parameters.detach().clone() + if self._lu is not None: + other._lu = [elem.clone() for elem in self._lu] + return other + + def stack(self, *others, dim=0): + """ + Stacks multiple ParameterizedTransform objects together. + + Args: + others: ParameterizedTransform objects to stack. + dim: Dimension along which to stack. Can be 0 for batch and 1 for joints. + + Returns: + new ParameterizedTransform object. + """ + other = self.__class__(dtype=self.dtype, device=self.device, requires_grad=self.requires_grad) + transforms = [self] + list(others) + parameters = torch.cat([t._parameters for t in transforms], dim=dim).to(self.device, dtype=self.dtype) + other._parameters = parameters + return other + + def to(self, device, copy: bool = False, dtype=None): + """Makes sure to also set parameters to the correct device.""" + other = super().to(device, copy, dtype) + if other is self and other._parameters.device == device and (dtype is None or dtype == other._parameters.dtype): + return self + other.parameters = self._parameters.detach().to(device, dtype) + return other + + def toTransform3d(self): + """Returns a Transform3d object with the same matrix as this ParameterizedTransform.""" + return Transform3d(matrix=self.get_matrix(), dtype=self.dtype, device=self.device) + + @property + def num_batch_levels(self) -> int: + """Returns the number of batch levels.""" + return len(self.default_batch_size) + + @property + def parameters(self) -> torch.Tensor: + """Returns the joint parameters""" + return self._parameters + + @parameters.setter + def parameters(self, parameters: torch.Tensor): + """Sets the joint parameters""" + if parameters.requires_grad != self.requires_grad: + # This check allows to set non-leaf parameters and is necessary for example for __getitem__ + parameters.requires_grad = self.requires_grad + self._parameters = parameters + + @classmethod + def get_num_parameters(cls) -> int: + """Returns the number of parameters""" + return len(cls.parameter_names) + + def __getitem__(self, item): + if isinstance(item, Iterable): + item = [i if not isinstance(i, int) else slice(i, i + 1) for i in item] # Make sure to not lose dimensions + return self.__class__(parameters=self.parameters[item], dtype=self.dtype, device=self.device, + default_batch_size=self.default_batch_size) + + def __repr__(self) -> str: + """Returns a string representation of the transform.""" + info = ', '.join([f'{name}={self.parameters[..., i]}' for i, name in enumerate(self.parameter_names)]) + return f"{self.__class__}({info})".replace('\n ', '') + + +class MDHTransform(ParameterizedTransform): + """A transformation derived from Denavit-Hartenberg parameters.""" + + convention: ParameterConvention = ParameterConvention.MDH + parameter_names: Tuple[str, str, str, str] = ('alpha', 'a', 'd', 'theta') + + @property + def theta(self): + """Returns the joint angle.""" + return self.parameters[..., 3] + + @property + def d(self): + """Returns the joint offset.""" + return self.parameters[..., 2] + + @property + def a(self): + """Returns the link length.""" + return self.parameters[..., 1] + + @property + def alpha(self): + """Returns the link twist.""" + return self.parameters[..., 0] + + def get_matrix(self) -> torch.Tensor: + """Returns the matrix representation of the transform. Redos the computation on every call""" + b = self.parameters.shape[0] + if self.num_batch_levels == 1: + self._matrix = mdh_to_homogeneous(self.parameters).view(b, 4, 4) + else: + self._matrix = mdh_to_homogeneous(self.parameters).view(b, -1, 4, 4) + return self._matrix + + def update_joint_parameters(self, th: torch.Tensor, joint_types: np.array): + """Updates the parameters of the parameters according to joint configuration th.""" + is_revolute = joint_types == 'revolute' + is_prismatic = joint_types == 'prismatic' + assert np.all(np.logical_xor(is_revolute, is_prismatic)) + self.parameters[is_revolute.nonzero()][:, 3] = th[is_revolute.nonzero()] + self.parameters[is_prismatic.nonzero()][:, 2] = th[is_prismatic.nonzero()] + + @classmethod + def from_homogeneous(cls, + homogeneous: torch.Tensor, + dtype: torch.dtype = torch.float32, + device: str = 'cpu', + requires_grad: bool = True, + ): + """Creates a MDHTransform from a homogeneous transformation matrix.""" + parameters = homogeneous_to_mdh(homogeneous) + return cls(parameters, dtype, device, requires_grad) + + +CONVENTION_IMPLEMENTATIONS = { + ParameterConvention.MDH: MDHTransform, +} diff --git a/src/pytorch_kinematics/transforms/transform3d.py b/src/pytorch_kinematics/transforms/transform3d.py index 1651f0c..2f71cd9 100644 --- a/src/pytorch_kinematics/transforms/transform3d.py +++ b/src/pytorch_kinematics/transforms/transform3d.py @@ -293,8 +293,9 @@ def inverse(self, invert_composed: bool = False): def stack(self, *others): transforms = [self] + list(others) - matrix = torch.cat([t._matrix for t in transforms], dim=0) - out = Transform3d(matrix=matrix, device=self.device, dtype=self.dtype) + matrix = torch.cat([t._matrix for t in transforms], dim=0).to(self.device, self.dtype) + out = self.__class__(device=self.device, dtype=self.dtype) + out._matrix = matrix return out def transform_points(self, points, eps: Optional[float] = None): diff --git a/src/pytorch_kinematics/visualize.py b/src/pytorch_kinematics/visualize.py new file mode 100644 index 0000000..564937c --- /dev/null +++ b/src/pytorch_kinematics/visualize.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Author: Jonathan Külz +# Date: 27.11.23 +from typing import Collection, Optional + +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import axes3d +import numpy as np +import torch + +from pytorch_kinematics import Transform3d + + +def visualize(frames: Transform3d, goal: Optional[torch.Tensor] = None, show: bool = False, **kwargs) -> plt.Axes: + """ + Visualizes a sequence of frames. + Args: + frames: The forward kinematics transforms of a single robot. + goal: The goal to visualize. + show: Whether to show the plot. + + Returns: None + """ + frames = np.vstack([f.get_matrix().cpu().detach().numpy() for f in frames]) + axis_size = np.max(frames[:, :3, 3]) - np.min(frames[:, :3, 3]) * .8 + center = np.mean(frames[:, :3, 3], axis=0) + frame_scale = axis_size / 15 + ax = plt.figure(figsize=(12, 12)).add_subplot(projection='3d') + frames = frames.reshape(-1, 4, 4) + num_frames = frames.shape[0] + draw_base(ax, scale=frame_scale, **kwargs) + if goal is not None: + draw_goal(ax, goal, scale=frame_scale, **kwargs) + for i, frame in enumerate(frames): + draw_frame(ax, frame, scale=frame_scale, **kwargs) + if i < num_frames - 1: + draw_link(ax, frame[:3, 3], frames[i + 1][:3, 3], **kwargs) + ax.set_xlim(center[0] - axis_size, center[0] + axis_size) + ax.set_ylim(center[1] - axis_size, center[1] + axis_size) + ax.set_zlim(center[2] - axis_size, center[2] + axis_size) + ax.set_xlabel('x') + ax.set_ylabel('y') + ax.set_zlabel('z') + if show: + plt.show() + return ax + + +def draw_base(ax: axes3d.Axes3D, scale: float = .1, **kwargs): + """Draw a sphere of radius scale at the origin.""" + u, v = np.mgrid[0:2 * np.pi:20j, 0:np.pi:10j] + x = scale * np.cos(u) * np.sin(v) + y = scale * np.sin(u) * np.sin(v) + z = scale * np.cos(v) + ax.plot_wireframe(x, y, z, color='gray') + + +def draw_goal(ax: axes3d.Axes3D, goal: torch.Tensor, scale: float = .1, **kwargs): + """Draw a sphere of radius scale at the origin.""" + u, v = np.mgrid[0:2 * np.pi:10j, 0:np.pi:5j] + goal = np.squeeze(goal.cpu().detach().numpy()) + x = goal[0] + scale * np.cos(u) * np.sin(v) + y = goal[1] + scale * np.sin(u) * np.sin(v) + z = goal[2] + scale * np.cos(v) + ax.plot_wireframe(x, y, z, color='red') + + +def draw_frame(ax: axes3d.Axes3D, frame: np.ndarray, scale: float = .1): + """ + Draws a frame. + Args: + ax: The axis to draw on. + frame: The frame placement. + scale: The length of the frame axes. + + Returns: None + """ + origin = frame[:3, 3] + x = frame[:3, 0] + y = frame[:3, 1] + z = frame[:3, 2] + ax.quiver(origin[0], origin[1], origin[2], x[0], x[1], x[2], length=scale, color='r', linewidths=3) + ax.quiver(origin[0], origin[1], origin[2], y[0], y[1], y[2], length=scale, color='g', linewidths=3) + ax.quiver(origin[0], origin[1], origin[2], z[0], z[1], z[2], length=scale, color='b', linewidths=3) + + +def draw_link(ax: axes3d.Axes3D, p0: np.ndarray, p1: np.ndarray, **kwargs): + """ + Draws a line from p0 to p1. + Args: + ax: The axis to draw on. + p0: The start point. + p1: The end point. + + Returns: None + """ + kwargs.setdefault('color', 'black') + ax.plot([p0[0], p1[0]], [p0[1], p1[1]], [p0[2], p1[2]], linewidth=3, **kwargs)