Skip to content

Commit c8ff9ff

Browse files
committed
removed processing from TorchMeshContainer
1 parent fb00595 commit c8ff9ff

File tree

9 files changed

+275
-152
lines changed

9 files changed

+275
-152
lines changed

roboreg/differentiable/structs.py

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from collections import OrderedDict
33
from typing import Dict, List, Optional, Tuple, Union
44

5-
import fast_simplification
65
import numpy as np
76
import torch
8-
import trimesh
7+
8+
from roboreg.io import Mesh
99

1010

1111
class TorchMeshContainer:
@@ -17,7 +17,7 @@ class TorchMeshContainer:
1717
"""
1818

1919
__slots__ = [
20-
"_mesh_names",
20+
"_names",
2121
"_vertices", # tensor of shape (B, N, 4) -> homogeneous coordinates
2222
"_faces", # tensor of shape (B, N, 3)
2323
"_per_mesh_vertex_count",
@@ -32,12 +32,11 @@ class TorchMeshContainer:
3232

3333
def __init__(
3434
self,
35-
mesh_paths: Dict[str, str],
35+
meshes: Dict[str, Mesh],
3636
batch_size: int = 1,
3737
device: torch.device = "cuda",
38-
target_reduction: float = 0.0,
3938
) -> None:
40-
self._mesh_names = []
39+
self._names = []
4140
self._vertices = []
4241
self._faces = []
4342
self._per_mesh_vertex_count = OrderedDict()
@@ -47,8 +46,8 @@ def __init__(
4746
self._lower_face_index_lookup = {}
4847
self._upper_face_index_lookup = {}
4948

50-
# load meshes
51-
self._populate_meshes(mesh_paths, device, target_reduction)
49+
# populate this container
50+
self._populate_container(meshes, device)
5251

5352
# add batch dim
5453
self._batch_size = batch_size
@@ -60,42 +59,22 @@ def __init__(
6059
self._device = device
6160

6261
@abc.abstractmethod
63-
def _load_mesh(self, mesh_path: str) -> trimesh.Trimesh:
64-
return trimesh.load(mesh_path)
65-
66-
@abc.abstractmethod
67-
def _populate_meshes(
62+
def _populate_container(
6863
self,
69-
mesh_paths: Dict[str, str],
64+
meshes: Dict[str, Mesh],
7065
device: torch.device = "cuda",
71-
target_reduction: float = 0.0,
7266
) -> None:
7367
offset = 0
74-
for mesh_name, mesh_path in mesh_paths.items():
68+
for name, mesh in meshes.items():
7569
# populate mesh names
76-
self._mesh_names.append(mesh_name)
77-
78-
# load mesh
79-
m = self._load_mesh(mesh_path)
80-
81-
if isinstance(m, trimesh.Scene):
82-
m = m.dump(concatenate=True)
83-
84-
vertices = m.vertices
85-
faces = m.faces
86-
87-
vertices, faces = fast_simplification.simplify(
88-
points=vertices,
89-
triangles=faces,
90-
target_reduction=target_reduction,
91-
)
70+
self._names.append(name)
9271

9372
# populate mesh vertex count
94-
self._per_mesh_vertex_count[mesh_name] = len(vertices)
73+
self._per_mesh_vertex_count[name] = len(mesh.vertices)
9574

9675
# populate vertices
9776
self._vertices.append(
98-
torch.tensor(vertices, dtype=torch.float32, device=device)
77+
torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
9978
)
10079
self._vertices[-1] = torch.cat(
10180
[
@@ -106,22 +85,22 @@ def _populate_meshes(
10685
) # (x,y,z) -> (x,y,z,1): homogeneous coordinates
10786

10887
# populate mesh face count
109-
self._per_mesh_face_count[mesh_name] = len(faces)
88+
self._per_mesh_face_count[name] = len(mesh.faces)
11089

11190
# populate faces (also add an offset to the point ids)
11291
self._faces.append(
11392
torch.add(
114-
torch.tensor(faces, dtype=torch.int32, device=device),
93+
torch.tensor(mesh.faces, dtype=torch.int32, device=device),
11594
offset,
11695
)
11796
)
118-
offset += len(vertices)
97+
offset += len(mesh.vertices)
11998

12099
self._vertices = torch.cat(self._vertices, dim=0)
121100
self._faces = torch.cat(self._faces, dim=0)
122101

123102
def _populate_index_lookups(self) -> None:
124-
if len(self._mesh_names) == 0:
103+
if len(self._names) == 0:
125104
raise ValueError("No meshes loaded.")
126105
if len(self._per_mesh_vertex_count) == 0:
127106
raise ValueError("No vertex counts populated.")
@@ -131,16 +110,16 @@ def _populate_index_lookups(self) -> None:
131110
# crucial: self._per_mesh_vertex_count sorted same as self._vertices! Same for faces.
132111
running_vertex_index = 0
133112
running_face_index = 0
134-
for mesh_name in self._mesh_names:
113+
for name in self._names:
135114
# vertex index lookup
136-
self._lower_vertex_index_lookup[mesh_name] = running_vertex_index
137-
running_vertex_index += self._per_mesh_vertex_count[mesh_name]
138-
self._upper_vertex_index_lookup[mesh_name] = running_vertex_index
115+
self._lower_vertex_index_lookup[name] = running_vertex_index
116+
running_vertex_index += self._per_mesh_vertex_count[name]
117+
self._upper_vertex_index_lookup[name] = running_vertex_index
139118

140119
# face index lookup
141-
self._lower_face_index_lookup[mesh_name] = running_face_index
142-
running_face_index += self._per_mesh_face_count[mesh_name]
143-
self._upper_face_index_lookup[mesh_name] = running_face_index
120+
self._lower_face_index_lookup[name] = running_face_index
121+
running_face_index += self._per_mesh_face_count[name]
122+
self._upper_face_index_lookup[name] = running_face_index
144123

145124
@property
146125
def vertices(self) -> torch.FloatTensor:
@@ -171,8 +150,8 @@ def upper_face_index_lookup(self) -> Dict[str, int]:
171150
return self._upper_face_index_lookup
172151

173152
@property
174-
def mesh_names(self) -> List[str]:
175-
return self._mesh_names
153+
def names(self) -> List[str]:
154+
return self._names
176155

177156
@property
178157
def device(self) -> torch.device:

roboreg/io/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .datasets import *
22
from .filesystem import *
3+
from .meshes import *
34
from .parsers import *

roboreg/io/meshes.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from dataclasses import dataclass
2+
from pathlib import Path
3+
from typing import Dict, Union
4+
5+
import fast_simplification
6+
import numpy as np
7+
import trimesh
8+
9+
10+
@dataclass
11+
class Mesh:
12+
r"""Dataclass to hold mesh data."""
13+
14+
vertices: np.ndarray
15+
faces: np.ndarray
16+
17+
18+
def load_mesh(path: Union[Path, str]) -> Mesh:
19+
r"""Reads a mesh file and returns vertices and faces.
20+
21+
Args:
22+
path (Union[Path, str]): Path to the mesh file.
23+
24+
Returns:
25+
Mesh:
26+
- Vertices of shape Nx3.
27+
- Faces of shape Nx3.
28+
"""
29+
path = Path(path)
30+
if not path.exists():
31+
raise FileNotFoundError(f"Mesh file {path} does not exist.")
32+
m = trimesh.load(path)
33+
if isinstance(m, trimesh.Scene):
34+
m = m.dump(concatenate=True)
35+
vertices, faces = m.vertices, m.faces
36+
if vertices.size == 0 or faces.size == 0:
37+
raise ValueError(f"Mesh is empty: {path.name}")
38+
return Mesh(vertices=vertices.view(np.ndarray), faces=faces.view(np.ndarray))
39+
40+
41+
def load_meshes(
42+
paths: Dict[str, Union[Path, str]],
43+
) -> Dict[str, Mesh]:
44+
r"""Load multiple meshes.
45+
46+
Args:
47+
paths (Dict[str, Union[Path, str]]): Mesh names and corresponding paths.
48+
49+
Returns:
50+
Dict[str, Mesh]:
51+
- Mesh name.
52+
- Mesh vertices of shape Nx3 and faces of shape Nx3.
53+
"""
54+
return {name: load_mesh(path) for name, path in paths.items()}
55+
56+
57+
def simplify_mesh(mesh: Mesh, target_reduction: float = 0.0) -> Mesh:
58+
r"""Simplify a mesh.
59+
60+
Args:
61+
mesh (Mesh): The mesh to be simplified.
62+
target_reduction (float): Target reduction in [0, 1]. Zero for no reduction.
63+
64+
Returns:
65+
Mesh: The simplified mesh.
66+
"""
67+
if 0.0 > target_reduction > 1.0:
68+
raise ValueError(
69+
f"Expected target reduction in [0, 1], got {target_reduction}."
70+
)
71+
vertices, faces = fast_simplification.simplify(
72+
points=mesh.vertices,
73+
triangles=mesh.faces,
74+
target_reduction=target_reduction,
75+
)
76+
return Mesh(vertices=vertices, faces=faces)
77+
78+
79+
def simplify_meshes(
80+
meshes: Dict[str, Mesh], target_reduction: float = 0.0
81+
) -> Dict[str, Mesh]:
82+
f"""Simplify multiple meshes.
83+
84+
Args:
85+
meshes (Dict[str, Mesh]): The meshes to be simplified.
86+
target_reduction (float): The target reduction in [0, 1]. Zero for no reduction.
87+
88+
Returns:
89+
Dict[str, Mesh]: The simplified meshes.
90+
"""
91+
return {
92+
name: simplify_mesh(mesh, target_reduction=target_reduction)
93+
for name, mesh in meshes.items()
94+
}

roboreg/util/sampling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Tuple
22

3-
import numpy as np
43
import torch
54

65

test/differentiable/test_kinematics.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
import os
2-
import sys
3-
4-
sys.path.append(
5-
os.path.dirname((os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
6-
)
7-
81
import cv2
92
import numpy as np
103
import pytest
@@ -15,7 +8,7 @@
158
from roboreg.differentiable.kinematics import TorchKinematics
169
from roboreg.differentiable.rendering import NVDiffRastRenderer
1710
from roboreg.differentiable.structs import TorchMeshContainer
18-
from roboreg.io import URDFParser
11+
from roboreg.io import URDFParser, load_meshes
1912
from roboreg.util import from_homogeneous
2013

2114

@@ -60,7 +53,10 @@ def test_torch_kinematics_on_mesh(
6053
device=device,
6154
)
6255
meshes = TorchMeshContainer(
63-
urdf_parser.ros_package_mesh_paths(root_link_name, end_link_name), device=device
56+
meshes=load_meshes(
57+
urdf_parser.ros_package_mesh_paths(root_link_name, end_link_name)
58+
),
59+
device=device,
6460
)
6561

6662
# compute forward kinematics and apply transforms to the meshes
@@ -112,7 +108,10 @@ def test_diff_kinematics() -> None:
112108
device=device,
113109
)
114110
meshes = TorchMeshContainer(
115-
urdf_parser.ros_package_mesh_paths("lbr_link_0", "lbr_link_7"), device=device
111+
meshes=load_meshes(
112+
urdf_parser.ros_package_mesh_paths("lbr_link_0", "lbr_link_7")
113+
),
114+
device=device,
116115
)
117116

118117
# compute forward kinematics and apply transforms to the meshes
@@ -220,22 +219,28 @@ def test_diff_kinematics() -> None:
220219

221220

222221
if __name__ == "__main__":
223-
# test_torch_kinematics(
224-
# ros_package="lbr_description",
225-
# xacro_path="urdf/med7/med7.xacro",
226-
# root_link_name="lbr_link_0",
227-
# end_link_name="lbr_link_7",
228-
# )
229-
test_torch_kinematics_on_mesh(
222+
import os
223+
import sys
224+
225+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
226+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
227+
228+
test_torch_kinematics(
230229
ros_package="lbr_description",
231230
xacro_path="urdf/med7/med7.xacro",
232231
root_link_name="lbr_link_0",
233232
end_link_name="lbr_link_7",
234233
)
235234
test_torch_kinematics_on_mesh(
236-
ros_package="xarm_description",
237-
xacro_path="urdf/xarm_device.urdf.xacro",
238-
root_link_name="link_base",
239-
end_link_name="link7",
235+
ros_package="lbr_description",
236+
xacro_path="urdf/med7/med7.xacro",
237+
root_link_name="lbr_link_0",
238+
end_link_name="lbr_link_7",
240239
)
240+
# test_torch_kinematics_on_mesh(
241+
# ros_package="xarm_description",
242+
# xacro_path="urdf/xarm_device.urdf.xacro",
243+
# root_link_name="link_base",
244+
# end_link_name="link7",
245+
# )
241246
# test_diff_kinematics()

0 commit comments

Comments
 (0)