Skip to content

Commit 95a268a

Browse files
committed
removed: origins from kinematics, inheritance and ROS from Robot
1 parent c8ff9ff commit 95a268a

File tree

13 files changed

+199
-97
lines changed

13 files changed

+199
-97
lines changed

cli/rr_cam_swarm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77

88
from roboreg import differentiable as rrd
9-
from roboreg.io import find_files, parse_camera_info, parse_mono_data
9+
from roboreg.io import URDFParser, find_files, parse_camera_info, parse_mono_data
1010
from roboreg.losses import soft_dice_loss
1111
from roboreg.optim import LinearParticleSwarm, ParticleSwarmOptimizer
1212
from roboreg.util import (
@@ -301,9 +301,9 @@ def main() -> None:
301301
device=device,
302302
)
303303

304-
urdf_parser = rrd.URDFParser()
304+
urdf_parser = URDFParser()
305305
urdf_parser.from_ros_xacro(ros_package=args.ros_package, xacro_path=args.xacro_path)
306-
robot = rrd.Robot(
306+
robot = rrd.Robot.from_urdf_parser(
307307
urdf_parser=urdf_parser,
308308
root_link_name=args.root_link_name,
309309
end_link_name=args.end_link_name,

cli/rr_hydra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def main():
183183

184184
# instantiate robot
185185
batch_size = len(joint_states)
186-
robot = Robot(
186+
robot = Robot.from_urdf_parser(
187187
urdf_parser=urdf_parser,
188188
root_link_name=root_link_name,
189189
end_link_name=end_link_name,

roboreg/differentiable/kinematics.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,31 @@
33
import pytorch_kinematics as pk
44
import torch
55

6-
from roboreg.io import URDFParser
7-
86

97
class TorchKinematics:
108
__slots__ = [
119
"_root_link_name",
1210
"_end_link_name",
1311
"_chain",
14-
"_mesh_origins_lookup",
1512
"_device",
1613
]
1714

1815
def __init__(
1916
self,
20-
urdf_parser: URDFParser,
17+
urdf: str,
2118
root_link_name: str,
2219
end_link_name: str,
23-
device: torch.device = "cuda",
20+
device: torch.device = torch.device("cuda"),
2421
) -> None:
2522
self._root_link_name = root_link_name
2623
self._end_link_name = end_link_name
2724
self._chain = self._build_serial_chain_from_urdf(
28-
urdf_parser.urdf,
25+
urdf=urdf,
2926
root_link_name=self._root_link_name,
3027
end_link_name=self._end_link_name,
3128
)
32-
33-
self._mesh_origins_lookup = urdf_parser.mesh_origins(
34-
root_link_name=root_link_name, end_link_name=end_link_name
35-
)
36-
self._mesh_origins_lookup = {
37-
key: torch.from_numpy(value).to(device=device, dtype=torch.float32)
38-
for key, value in self._mesh_origins_lookup.items()
39-
}
40-
41-
# default move to device
42-
self.to(device=device)
29+
self._device = device
30+
self.to(device=self._device)
4331

4432
def _build_serial_chain_from_urdf(
4533
self, urdf: str, root_link_name: str, end_link_name: str
@@ -50,18 +38,11 @@ def _build_serial_chain_from_urdf(
5038

5139
def to(self, device: torch.device) -> None:
5240
self._chain.to(device=device)
53-
for link_name in self._mesh_origins_lookup:
54-
self._mesh_origins_lookup[link_name] = self._mesh_origins_lookup[
55-
link_name
56-
].to(device=device)
5741
self._device = device
5842

59-
def mesh_forward_kinematics(self, q: torch.Tensor) -> Dict[str, torch.Tensor]:
60-
r"""Computes forward kinematics and returns corresponding homogeneous transformations.
61-
Corrects for mesh offsets. Meshes that are tranformed by the returned transformation appear physically correct.
62-
"""
43+
def forward_kinematics(self, q: torch.Tensor) -> Dict[str, torch.Tensor]:
6344
ht_lookup = {
64-
key: value.get_matrix() @ self._mesh_origins_lookup[key]
45+
key: value.get_matrix()
6546
for key, value in self._chain.forward_kinematics(q, end_only=False).items()
6647
}
6748
return ht_lookup

roboreg/differentiable/rendering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class NVDiffRastRenderer:
1313

1414
__slots__ = ["_device", "_ctx"]
1515

16-
def __init__(self, device: torch.device = "cuda") -> None:
16+
def __init__(self, device: torch.device = torch.device("cuda")) -> None:
1717
super().__init__()
1818
if not torch.cuda.is_available():
1919
raise ValueError("CUDA is not available.")

roboreg/differentiable/robot.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,69 @@
66
from .structs import TorchMeshContainer
77

88

9-
class Robot(TorchMeshContainer):
10-
__slots__ = ["_kinematics", "_configured_vertices"]
9+
class Robot:
10+
__slots__ = ["_mesh_container", "_kinematics", "_configured_vertices", "_device"]
1111

1212
def __init__(
1313
self,
14+
mesh_container: TorchMeshContainer,
15+
kinematics: TorchKinematics,
16+
device: torch.device = torch.device("cuda"),
17+
) -> None:
18+
self._mesh_container = mesh_container
19+
self._kinematics = kinematics
20+
self._configured_vertices = self.mesh_container.vertices.clone()
21+
self._device = device
22+
self.to(device=self._device)
23+
24+
@classmethod
25+
def from_urdf_parser(
26+
cls,
1427
urdf_parser: URDFParser,
1528
root_link_name: str,
1629
end_link_name: str,
1730
collision: bool = False,
1831
batch_size: int = 1,
19-
device: torch.device = "cuda",
32+
device: torch.device = torch.device("cuda"),
2033
target_reduction: float = 0.0,
21-
) -> None:
22-
super().__init__(
23-
mesh_paths=urdf_parser.ros_package_mesh_paths(
24-
root_link_name=root_link_name,
25-
end_link_name=end_link_name,
26-
collision=collision,
27-
),
34+
) -> "Robot":
35+
from roboreg.io import apply_mesh_origins, load_meshes, simplify_meshes
36+
37+
# parse data from URDF
38+
mesh_paths = urdf_parser.ros_package_mesh_paths(
39+
root_link_name=root_link_name,
40+
end_link_name=end_link_name,
41+
collision=collision,
42+
)
43+
mesh_origins = urdf_parser.mesh_origins(
44+
root_link_name=root_link_name,
45+
end_link_name=end_link_name,
46+
collision=collision,
47+
)
48+
49+
# load and preprocess meshes
50+
meshes = load_meshes(paths=mesh_paths)
51+
meshes = simplify_meshes(
52+
meshes=meshes,
53+
target_reduction=target_reduction,
54+
)
55+
meshes = apply_mesh_origins(meshes=meshes, origins=mesh_origins)
56+
57+
# configure this robot
58+
mesh_container = TorchMeshContainer(
59+
meshes=meshes,
2860
batch_size=batch_size,
2961
device=device,
30-
target_reduction=target_reduction,
3162
)
32-
self._kinematics = TorchKinematics(
33-
urdf_parser=urdf_parser,
63+
64+
kinematics = TorchKinematics(
65+
urdf=urdf_parser.urdf,
3466
root_link_name=root_link_name,
3567
end_link_name=end_link_name,
3668
device=device,
3769
)
38-
self._configured_vertices = self.vertices.clone()
70+
71+
return cls(mesh_container=mesh_container, kinematics=kinematics, device=device)
3972

4073
def configure(
4174
self, q: torch.FloatTensor, ht_root: torch.FloatTensor = None
@@ -44,37 +77,51 @@ def configure(
4477
raise ValueError(
4578
f"Expected joint states of shape {self._kinematics.chain.n_joints}, got {q.shape[-1]}."
4679
)
47-
if q.shape[0] != self._batch_size:
80+
if q.shape[0] != self._mesh_container.batch_size:
4881
raise ValueError(
49-
f"Batch size mismatch. Meshes: {self._batch_size}, joint states: {q.shape[0]}."
82+
f"Batch size mismatch. Meshes: {self._mesh_container.batch_size}, joint states: {q.shape[0]}."
5083
)
5184
if ht_root is None:
5285
ht_root = torch.eye(4, device=self._device).unsqueeze(0)
53-
ht_target_lookup = self._kinematics.mesh_forward_kinematics(q)
54-
self._configured_vertices = self.vertices.clone()
86+
ht_target_lookup = self._kinematics.forward_kinematics(q)
87+
self._configured_vertices = self.mesh_container.vertices.clone()
5588
for link_name, ht in ht_target_lookup.items():
5689
self._configured_vertices[
5790
:,
58-
self.lower_vertex_index_lookup[
91+
self.mesh_container.lower_vertex_index_lookup[
5992
link_name
60-
] : self.upper_vertex_index_lookup[link_name],
93+
] : self.mesh_container.upper_vertex_index_lookup[link_name],
6194
] = torch.matmul(
6295
torch.matmul(
6396
self._configured_vertices[
6497
:,
65-
self.lower_vertex_index_lookup[
98+
self.mesh_container.lower_vertex_index_lookup[
6699
link_name
67-
] : self.upper_vertex_index_lookup[link_name],
100+
] : self.mesh_container.upper_vertex_index_lookup[link_name],
68101
],
69102
ht.transpose(-1, -2),
70103
),
71104
ht_root.transpose(-1, -2),
72105
)
73106

107+
def to(self, device: torch.device) -> None:
108+
self._mesh_container.to(device=device)
109+
self._kinematics.to(device=device)
110+
self._configured_vertices = self._configured_vertices.to(device=device)
111+
self._device = device
112+
74113
@property
75114
def kinematics(self) -> TorchKinematics:
76115
return self._kinematics
77116

117+
@property
118+
def mesh_container(self) -> TorchMeshContainer:
119+
return self._mesh_container
120+
78121
@property
79122
def configured_vertices(self) -> torch.FloatTensor:
80123
return self._configured_vertices
124+
125+
@property
126+
def device(self) -> torch.device:
127+
return self._device

roboreg/differentiable/scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def observe_from(
6666
)
6767
return self._renderer.constant_color(
6868
observed_vertices,
69-
self._robot.faces,
69+
self._robot.mesh_container.faces,
7070
self._cameras[camera_name].resolution,
7171
)
7272

roboreg/differentiable/structs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
self,
3535
meshes: Dict[str, Mesh],
3636
batch_size: int = 1,
37-
device: torch.device = "cuda",
37+
device: torch.device = torch.device("cuda"),
3838
) -> None:
3939
self._names = []
4040
self._vertices = []
@@ -62,7 +62,7 @@ def __init__(
6262
def _populate_container(
6363
self,
6464
meshes: Dict[str, Mesh],
65-
device: torch.device = "cuda",
65+
device: torch.device("cuda"),
6666
) -> None:
6767
offset = 0
6868
for name, mesh in meshes.items():
@@ -184,7 +184,7 @@ def __init__(
184184
resolution: Tuple[int, int],
185185
intrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
186186
extrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
187-
device: torch.device = "cuda",
187+
device: torch.device = torch.device("cuda"),
188188
name: str = "camera",
189189
) -> None:
190190
if intrinsics is None:
@@ -312,7 +312,7 @@ def __init__(
312312
extrinsics: Optional[Union[torch.FloatTensor, np.ndarray]] = None,
313313
zmin: float = 0.1,
314314
zmax: float = 100.0,
315-
device: torch.device = "cuda",
315+
device: torch.device = torch.device("cuda"),
316316
) -> None:
317317
super().__init__(resolution, intrinsics, extrinsics, device)
318318

roboreg/io/meshes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def simplify_mesh(mesh: Mesh, target_reduction: float = 0.0) -> Mesh:
6464
Returns:
6565
Mesh: The simplified mesh.
6666
"""
67+
if target_reduction == 0.0:
68+
return mesh
6769
if 0.0 > target_reduction > 1.0:
6870
raise ValueError(
6971
f"Expected target reduction in [0, 1], got {target_reduction}."
@@ -88,7 +90,41 @@ def simplify_meshes(
8890
Returns:
8991
Dict[str, Mesh]: The simplified meshes.
9092
"""
93+
if target_reduction == 0.0:
94+
return meshes
9195
return {
9296
name: simplify_mesh(mesh, target_reduction=target_reduction)
9397
for name, mesh in meshes.items()
9498
}
99+
100+
101+
def apply_mesh_origin(vertices: np.ndarray, origin: np.ndarray) -> np.ndarray:
102+
r"""Apply a homogeneous transformation to mesh vertices.
103+
104+
Args:
105+
vertices (np.ndarray): The mesh vertices of shape Nx3.
106+
origin (np.ndarray): The mesh origin as a homogeneous transform of shape 4x4.
107+
Returns:
108+
np.ndarray: The transformed mesh vertices of shape Nx3.
109+
"""
110+
return vertices @ origin[:3, :3].T + origin[:3, 3].T
111+
112+
113+
def apply_mesh_origins(
114+
meshes: Dict[str, Mesh], origins: Dict[str, np.ndarray]
115+
) -> Dict[str, Mesh]:
116+
r"""Apply mesh origins to multiple meshes.
117+
118+
Args:
119+
meshes (Dict[str, Mesh]): The meshes to apply origins to.
120+
origins (Dict[str, np.ndarray]): The mesh origins.
121+
122+
Returns:
123+
Dict[str, Mesh]: The meshes with applied origins.
124+
"""
125+
for name in meshes.keys():
126+
if name in origins:
127+
meshes[name].vertices = apply_mesh_origin(
128+
meshes[name].vertices, origins[name]
129+
)
130+
return meshes

roboreg/util/factories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def create_robot_scene(
5656
)
5757

5858
# instantiate robot
59-
robot = Robot(
59+
robot = Robot.from_urdf_parser(
6060
urdf_parser=urdf_parser,
6161
root_link_name=root_link_name,
6262
end_link_name=end_link_name,

0 commit comments

Comments
 (0)