Skip to content

Commit 0ac2a0a

Browse files
author
Peter
committed
all tests pass!
1 parent dbd2d02 commit 0ac2a0a

File tree

3 files changed

+69
-12
lines changed

3 files changed

+69
-12
lines changed

src/pytorch_kinematics/chain.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ def ensure_2d_tensor(th, dtype, device):
3838
return th, N
3939

4040

41+
def get_dict_elem_shape(th_dict):
42+
elem = th_dict[list(th_dict.keys())[0]]
43+
if isinstance(elem, np.ndarray):
44+
return elem.shape
45+
elif isinstance(elem, torch.Tensor):
46+
return elem.shape
47+
else:
48+
return ()
49+
50+
4151
class Chain:
4252
"""
4353
Robot model that may be constructed from different descriptions via their respective parsers.
@@ -174,6 +184,7 @@ def get_joints(self, exclude_fixed=True):
174184
joints = self._get_joints(self._root, exclude_fixed=exclude_fixed)
175185
return joints
176186

187+
@lru_cache()
177188
def get_joint_parameter_names(self, exclude_fixed=True):
178189
names = []
179190
for f in self.get_joints(exclude_fixed=exclude_fixed):
@@ -270,11 +281,16 @@ def get_frame_indices(self, *frame_names):
270281

271282
def forward_kinematics(self, th, frame_indices: Optional = None):
272283
"""
273-
Instead of a tree, we can use a flat data structure with indexes to represent the parent
274-
then instead of recursion we can just iterate in order and use parent pointers. This
275-
reduces function call overhead and moves some of the indexing work to the constructor.
284+
Compute forward kinematics.
285+
286+
Args:
287+
th: A dict, list, numpy array, or torch tensor of ALL joints values. Possibly batched.
288+
the fastest thing to use is a torch tensor, all other types get converted to that.
289+
If any joint values are missing, an exception will be thrown.
290+
frame_indices: A list of frame indices to compute forward kinematics for. If None, all frames are computed.
276291
277-
use `get_frame_indices` to get the indices for the frames you want to compute the pose for.
292+
Returns:
293+
A dict of frame names and their corresponding Transform3d objects.
278294
"""
279295
if frame_indices is None:
280296
frame_indices = self.get_all_frame_indices()
@@ -343,18 +359,28 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
343359
return frame_names_and_transform3ds
344360

345361
def ensure_tensor(self, th):
362+
"""
363+
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
364+
of self.get_joint_parameter_names().
365+
"""
346366
if isinstance(th, np.ndarray):
347367
th = torch.tensor(th, device=self.device, dtype=self.dtype)
348-
if isinstance(th, list):
368+
elif isinstance(th, list):
349369
th = torch.tensor(th, device=self.device, dtype=self.dtype)
350-
if isinstance(th, dict):
370+
elif isinstance(th, dict):
351371
# convert dict to a flat, complete, tensor of all joints values. Missing joints are filled with zeros.
352372
th_dict = th
353-
th = torch.zeros(self.n_joints, device=self.device, dtype=self.dtype)
373+
elem_shape = get_dict_elem_shape(th_dict)
374+
th = torch.ones([*elem_shape, self.n_joints], device=self.device, dtype=self.dtype) * torch.nan
354375
joint_names = self.get_joint_parameter_names()
355376
for joint_name, joint_position in th_dict.items():
356377
jnt_idx = joint_names.index(joint_name)
357-
th[jnt_idx] = joint_position
378+
th[..., jnt_idx] = joint_position
379+
if torch.any(torch.isnan(th)):
380+
msg = "Missing values for the following joints:\n"
381+
for joint_name, th_i in zip(self.get_joint_parameter_names(), th):
382+
msg += joint_name + "\n"
383+
raise ValueError(msg)
358384
return th
359385

360386
def get_all_frame_indices(self):
@@ -454,6 +480,7 @@ def jacobian(self, th, locations=None):
454480
return jacobian.calc_jacobian(self, th, tool=locations)
455481

456482
def forward_kinematics(self, th, end_only: bool = True):
483+
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
457484
if end_only:
458485
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
459486
else:

src/pytorch_kinematics/sdf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ def build_chain_from_sdf(data):
100100
_convert_visuals(root_link.visuals))
101101
root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
102102
return chain.Chain(root_frame)
103+
104+
105+
def build_serial_chain_from_sdf(data, end_link_name, root_link_name=""):
106+
mjcf_chain = build_chain_from_sdf(data)
107+
serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name)
108+
return serial_chain

tests/test_kinematics.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@ def test_fk_mjcf():
3636
chain = chain.to(dtype=torch.float64)
3737
print(chain)
3838
print(chain.get_joint_parameter_names())
39-
th = {'hip_1': 1.0, 'ankle_1': 1}
39+
th = {
40+
'hip_1': 1.0,
41+
'ankle_1': 1,
42+
'hip_2': 0.0,
43+
'ankle_2': 0.0,
44+
'hip_3': 0.0,
45+
'ankle_3': 0.0,
46+
'hip_4': 0.0,
47+
'ankle_4': 0.0,
48+
}
4049
ret = chain.forward_kinematics(th)
4150
tg = ret['aux_1']
4251
pos, rot = quat_pos_from_transform3d(tg)
@@ -151,14 +160,24 @@ def test_fk_simple_arm():
151160
chain = chain.to(dtype=torch.float64)
152161
# print(chain)
153162
# print(chain.get_joint_parameter_names())
154-
ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5})
163+
ret = chain.forward_kinematics({
164+
'arm_shoulder_pan_joint': 0.,
165+
'arm_elbow_pan_joint': math.pi / 2.0,
166+
'arm_wrist_lift_joint': -0.5,
167+
'arm_wrist_roll_joint': 0.,
168+
})
155169
tg = ret['arm_wrist_roll']
156170
pos, rot = quat_pos_from_transform3d(tg)
157171
assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
158172
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
159173

160174
N = 100
161-
ret = chain.forward_kinematics({'arm_elbow_pan_joint': torch.rand(N, 1), 'arm_wrist_lift_joint': torch.rand(N, 1)})
175+
ret = chain.forward_kinematics({
176+
'arm_shoulder_pan_joint': torch.rand(N),
177+
'arm_elbow_pan_joint': torch.rand(N),
178+
'arm_wrist_lift_joint': torch.rand(N),
179+
'arm_wrist_roll_joint': torch.rand(N),
180+
})
162181
tg = ret['arm_wrist_roll']
163182
assert list(tg.get_matrix().shape) == [N, 4, 4]
164183

@@ -176,7 +195,12 @@ def test_cuda():
176195
chain = pk.build_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read())
177196
chain = chain.to(dtype=dtype, device=d)
178197

179-
ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5})
198+
ret = chain.forward_kinematics({
199+
'arm_shoulder_pan_joint': 0,
200+
'arm_elbow_pan_joint': math.pi / 2.0,
201+
'arm_wrist_lift_joint': -0.5,
202+
'arm_wrist_roll_joint': 0,
203+
})
180204
tg = ret['arm_wrist_roll']
181205
pos, rot = quat_pos_from_transform3d(tg)
182206
assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d))

0 commit comments

Comments
 (0)