55import rich
66import torch
77
8- from roboreg .differentiable import TorchKinematics , TorchMeshContainer
8+ from roboreg .differentiable import Robot
99from roboreg .hydra_icp import hydra_centroid_alignment , hydra_robust_icp
1010from roboreg .io import URDFParser , parse_camera_info , parse_hydra_data
1111from roboreg .util import (
@@ -178,46 +178,22 @@ def main():
178178 rich .print (
179179 f"End link name not provided. Using the last link with mesh: '{ end_link_name } '."
180180 )
181- kinematics = TorchKinematics (
181+
182+ # instantiate robot
183+ batch_size = len (joint_states )
184+ robot = Robot (
182185 urdf_parser = urdf_parser ,
183- device = device ,
184186 root_link_name = root_link_name ,
185187 end_link_name = end_link_name ,
186- )
187-
188- # instantiate mesh
189- batch_size = len (joint_states )
190- meshes = TorchMeshContainer (
191- mesh_paths = urdf_parser .ros_package_mesh_paths (
192- root_link_name = root_link_name ,
193- end_link_name = end_link_name ,
194- visual = args .visual_meshes ,
195- ),
188+ visual = args .visual_meshes ,
196189 batch_size = batch_size ,
197- device = device ,
198190 )
199191
200192 # perform forward kinematics
201- mesh_vertices = meshes .vertices .clone ()
202193 joint_states = torch .tensor (
203194 np .array (joint_states ), dtype = torch .float32 , device = device
204195 )
205- ht_lookup = kinematics .mesh_forward_kinematics (joint_states )
206- for link_name , ht in ht_lookup .items ():
207- mesh_vertices [
208- :,
209- meshes .lower_vertex_index_lookup [
210- link_name
211- ] : meshes .upper_vertex_index_lookup [link_name ],
212- ] = torch .matmul (
213- mesh_vertices [
214- :,
215- meshes .lower_vertex_index_lookup [
216- link_name
217- ] : meshes .upper_vertex_index_lookup [link_name ],
218- ],
219- ht .transpose (- 1 , - 2 ),
220- )
196+ robot .configure (joint_states )
221197
222198 # turn depths into xyzs
223199 intrinsics = torch .tensor (intrinsics , dtype = torch .float32 , device = device )
@@ -242,12 +218,12 @@ def main():
242218 xyzs = [xyz .squeeze () for xyz in xyzs .cpu ().numpy ()]
243219
244220 # mesh vertices to list
245- mesh_vertices = from_homogeneous (mesh_vertices )
221+ mesh_vertices = from_homogeneous (robot . configured_vertices )
246222 mesh_vertices = [mesh_vertices [i ].contiguous () for i in range (batch_size )]
247223 mesh_normals = []
248224 for i in range (batch_size ):
249225 mesh_normals .append (
250- compute_vertex_normals (vertices = mesh_vertices [i ], faces = meshes .faces )
226+ compute_vertex_normals (vertices = mesh_vertices [i ], faces = robot .faces )
251227 )
252228
253229 # clean observed vertices and turn into tensor
0 commit comments