66from .structs import TorchMeshContainer
77
88
9- class Robot ( TorchMeshContainer ) :
10- __slots__ = ["_kinematics" , "_configured_vertices" ]
9+ class Robot :
10+ __slots__ = ["_mesh_container" , " _kinematics" , "_configured_vertices" , "_device " ]
1111
1212 def __init__ (
1313 self ,
14+ mesh_container : TorchMeshContainer ,
15+ kinematics : TorchKinematics ,
16+ device : torch .device = torch .device ("cuda" ),
17+ ) -> None :
18+ self ._mesh_container = mesh_container
19+ self ._kinematics = kinematics
20+ self ._configured_vertices = self .mesh_container .vertices .clone ()
21+ self ._device = device
22+ self .to (device = self ._device )
23+
24+ @classmethod
25+ def from_urdf_parser (
26+ cls ,
1427 urdf_parser : URDFParser ,
1528 root_link_name : str ,
1629 end_link_name : str ,
1730 collision : bool = False ,
1831 batch_size : int = 1 ,
19- device : torch .device = "cuda" ,
32+ device : torch .device = torch . device ( "cuda" ) ,
2033 target_reduction : float = 0.0 ,
21- ) -> None :
22- super ().__init__ (
23- mesh_paths = urdf_parser .ros_package_mesh_paths (
24- root_link_name = root_link_name ,
25- end_link_name = end_link_name ,
26- collision = collision ,
27- ),
34+ ) -> "Robot" :
35+ from roboreg .io import apply_mesh_origins , load_meshes , simplify_meshes
36+
37+ # parse data from URDF
38+ mesh_paths = urdf_parser .ros_package_mesh_paths (
39+ root_link_name = root_link_name ,
40+ end_link_name = end_link_name ,
41+ collision = collision ,
42+ )
43+ mesh_origins = urdf_parser .mesh_origins (
44+ root_link_name = root_link_name ,
45+ end_link_name = end_link_name ,
46+ collision = collision ,
47+ )
48+
49+ # load and preprocess meshes
50+ meshes = load_meshes (paths = mesh_paths )
51+ meshes = simplify_meshes (
52+ meshes = meshes ,
53+ target_reduction = target_reduction ,
54+ )
55+ meshes = apply_mesh_origins (meshes = meshes , origins = mesh_origins )
56+
57+ # configure this robot
58+ mesh_container = TorchMeshContainer (
59+ meshes = meshes ,
2860 batch_size = batch_size ,
2961 device = device ,
30- target_reduction = target_reduction ,
3162 )
32- self ._kinematics = TorchKinematics (
33- urdf_parser = urdf_parser ,
63+
64+ kinematics = TorchKinematics (
65+ urdf = urdf_parser .urdf ,
3466 root_link_name = root_link_name ,
3567 end_link_name = end_link_name ,
3668 device = device ,
3769 )
38- self ._configured_vertices = self .vertices .clone ()
70+
71+ return cls (mesh_container = mesh_container , kinematics = kinematics , device = device )
3972
4073 def configure (
4174 self , q : torch .FloatTensor , ht_root : torch .FloatTensor = None
@@ -44,37 +77,51 @@ def configure(
4477 raise ValueError (
4578 f"Expected joint states of shape { self ._kinematics .chain .n_joints } , got { q .shape [- 1 ]} ."
4679 )
47- if q .shape [0 ] != self ._batch_size :
80+ if q .shape [0 ] != self ._mesh_container . batch_size :
4881 raise ValueError (
49- f"Batch size mismatch. Meshes: { self ._batch_size } , joint states: { q .shape [0 ]} ."
82+ f"Batch size mismatch. Meshes: { self ._mesh_container . batch_size } , joint states: { q .shape [0 ]} ."
5083 )
5184 if ht_root is None :
5285 ht_root = torch .eye (4 , device = self ._device ).unsqueeze (0 )
53- ht_target_lookup = self ._kinematics .mesh_forward_kinematics (q )
54- self ._configured_vertices = self .vertices .clone ()
86+ ht_target_lookup = self ._kinematics .forward_kinematics (q )
87+ self ._configured_vertices = self .mesh_container . vertices .clone ()
5588 for link_name , ht in ht_target_lookup .items ():
5689 self ._configured_vertices [
5790 :,
58- self .lower_vertex_index_lookup [
91+ self .mesh_container . lower_vertex_index_lookup [
5992 link_name
60- ] : self .upper_vertex_index_lookup [link_name ],
93+ ] : self .mesh_container . upper_vertex_index_lookup [link_name ],
6194 ] = torch .matmul (
6295 torch .matmul (
6396 self ._configured_vertices [
6497 :,
65- self .lower_vertex_index_lookup [
98+ self .mesh_container . lower_vertex_index_lookup [
6699 link_name
67- ] : self .upper_vertex_index_lookup [link_name ],
100+ ] : self .mesh_container . upper_vertex_index_lookup [link_name ],
68101 ],
69102 ht .transpose (- 1 , - 2 ),
70103 ),
71104 ht_root .transpose (- 1 , - 2 ),
72105 )
73106
107+ def to (self , device : torch .device ) -> None :
108+ self ._mesh_container .to (device = device )
109+ self ._kinematics .to (device = device )
110+ self ._configured_vertices = self ._configured_vertices .to (device = device )
111+ self ._device = device
112+
74113 @property
75114 def kinematics (self ) -> TorchKinematics :
76115 return self ._kinematics
77116
117+ @property
118+ def mesh_container (self ) -> TorchMeshContainer :
119+ return self ._mesh_container
120+
78121 @property
79122 def configured_vertices (self ) -> torch .FloatTensor :
80123 return self ._configured_vertices
124+
125+ @property
126+ def device (self ) -> torch .device :
127+ return self ._device
0 commit comments