1- from typing import Dict
1+ from typing import Dict , Union
22
33import pytorch_kinematics as pk
44import torch
55
6- from roboreg .io import URDFParser
7-
86
97class TorchKinematics :
108 __slots__ = [
119 "_root_link_name" ,
1210 "_end_link_name" ,
1311 "_chain" ,
14- "_mesh_origins_lookup" ,
1512 "_device" ,
1613 ]
1714
1815 def __init__ (
1916 self ,
20- urdf_parser : URDFParser ,
17+ urdf : str ,
2118 root_link_name : str ,
2219 end_link_name : str ,
23- device : torch .device = "cuda" ,
20+ device : Union [ torch .device , str ] = "cuda" ,
2421 ) -> None :
2522 self ._root_link_name = root_link_name
2623 self ._end_link_name = end_link_name
2724 self ._chain = self ._build_serial_chain_from_urdf (
28- urdf_parser . urdf ,
25+ urdf = urdf ,
2926 root_link_name = self ._root_link_name ,
3027 end_link_name = self ._end_link_name ,
3128 )
32-
33- self ._mesh_origins_lookup = urdf_parser .mesh_origins (
34- root_link_name = root_link_name , end_link_name = end_link_name
35- )
36- self ._mesh_origins_lookup = {
37- key : torch .from_numpy (value ).to (device = device , dtype = torch .float32 )
38- for key , value in self ._mesh_origins_lookup .items ()
39- }
40-
41- # default move to device
42- self .to (device = device )
29+ self ._device = torch .device (device ) if isinstance (device , str ) else device
30+ self .to (device = self ._device )
4331
4432 def _build_serial_chain_from_urdf (
4533 self , urdf : str , root_link_name : str , end_link_name : str
@@ -48,20 +36,13 @@ def _build_serial_chain_from_urdf(
4836 urdf , end_link_name = end_link_name , root_link_name = root_link_name
4937 )
5038
51- def to (self , device : torch .device ) -> None :
39+ def to (self , device : Union [ torch .device , str ] ) -> None :
5240 self ._chain .to (device = device )
53- for link_name in self ._mesh_origins_lookup :
54- self ._mesh_origins_lookup [link_name ] = self ._mesh_origins_lookup [
55- link_name
56- ].to (device = device )
57- self ._device = device
41+ self ._device = torch .device (device ) if isinstance (device , str ) else device
5842
59- def mesh_forward_kinematics (self , q : torch .Tensor ) -> Dict [str , torch .Tensor ]:
60- r"""Computes forward kinematics and returns corresponding homogeneous transformations.
61- Corrects for mesh offsets. Meshes that are tranformed by the returned transformation appear physically correct.
62- """
43+ def forward_kinematics (self , q : torch .Tensor ) -> Dict [str , torch .Tensor ]:
6344 ht_lookup = {
64- key : value .get_matrix () @ self . _mesh_origins_lookup [ key ]
45+ key : value .get_matrix ()
6546 for key , value in self ._chain .forward_kinematics (q , end_only = False ).items ()
6647 }
6748 return ht_lookup
0 commit comments