22from collections import OrderedDict
33from typing import Dict , List , Optional , Tuple , Union
44
5- import fast_simplification
65import numpy as np
76import torch
8- import trimesh
7+
8+ from roboreg .io import Mesh
99
1010
1111class TorchMeshContainer :
@@ -17,7 +17,7 @@ class TorchMeshContainer:
1717 """
1818
1919 __slots__ = [
20- "_mesh_names " ,
20+ "_names " ,
2121 "_vertices" , # tensor of shape (B, N, 4) -> homogeneous coordinates
2222 "_faces" , # tensor of shape (B, N, 3)
2323 "_per_mesh_vertex_count" ,
@@ -32,12 +32,11 @@ class TorchMeshContainer:
3232
3333 def __init__ (
3434 self ,
35- mesh_paths : Dict [str , str ],
35+ meshes : Dict [str , Mesh ],
3636 batch_size : int = 1 ,
3737 device : torch .device = "cuda" ,
38- target_reduction : float = 0.0 ,
3938 ) -> None :
40- self ._mesh_names = []
39+ self ._names = []
4140 self ._vertices = []
4241 self ._faces = []
4342 self ._per_mesh_vertex_count = OrderedDict ()
@@ -47,8 +46,8 @@ def __init__(
4746 self ._lower_face_index_lookup = {}
4847 self ._upper_face_index_lookup = {}
4948
50- # load meshes
51- self ._populate_meshes ( mesh_paths , device , target_reduction )
49+ # populate this container
50+ self ._populate_container ( meshes , device )
5251
5352 # add batch dim
5453 self ._batch_size = batch_size
@@ -60,42 +59,22 @@ def __init__(
6059 self ._device = device
6160
6261 @abc .abstractmethod
63- def _load_mesh (self , mesh_path : str ) -> trimesh .Trimesh :
64- return trimesh .load (mesh_path )
65-
66- @abc .abstractmethod
67- def _populate_meshes (
62+ def _populate_container (
6863 self ,
69- mesh_paths : Dict [str , str ],
64+ meshes : Dict [str , Mesh ],
7065 device : torch .device = "cuda" ,
71- target_reduction : float = 0.0 ,
7266 ) -> None :
7367 offset = 0
74- for mesh_name , mesh_path in mesh_paths .items ():
68+ for name , mesh in meshes .items ():
7569 # populate mesh names
76- self ._mesh_names .append (mesh_name )
77-
78- # load mesh
79- m = self ._load_mesh (mesh_path )
80-
81- if isinstance (m , trimesh .Scene ):
82- m = m .dump (concatenate = True )
83-
84- vertices = m .vertices
85- faces = m .faces
86-
87- vertices , faces = fast_simplification .simplify (
88- points = vertices ,
89- triangles = faces ,
90- target_reduction = target_reduction ,
91- )
70+ self ._names .append (name )
9271
9372 # populate mesh vertex count
94- self ._per_mesh_vertex_count [mesh_name ] = len (vertices )
73+ self ._per_mesh_vertex_count [name ] = len (mesh . vertices )
9574
9675 # populate vertices
9776 self ._vertices .append (
98- torch .tensor (vertices , dtype = torch .float32 , device = device )
77+ torch .tensor (mesh . vertices , dtype = torch .float32 , device = device )
9978 )
10079 self ._vertices [- 1 ] = torch .cat (
10180 [
@@ -106,22 +85,22 @@ def _populate_meshes(
10685 ) # (x,y,z) -> (x,y,z,1): homogeneous coordinates
10786
10887 # populate mesh face count
109- self ._per_mesh_face_count [mesh_name ] = len (faces )
88+ self ._per_mesh_face_count [name ] = len (mesh . faces )
11089
11190 # populate faces (also add an offset to the point ids)
11291 self ._faces .append (
11392 torch .add (
114- torch .tensor (faces , dtype = torch .int32 , device = device ),
93+ torch .tensor (mesh . faces , dtype = torch .int32 , device = device ),
11594 offset ,
11695 )
11796 )
118- offset += len (vertices )
97+ offset += len (mesh . vertices )
11998
12099 self ._vertices = torch .cat (self ._vertices , dim = 0 )
121100 self ._faces = torch .cat (self ._faces , dim = 0 )
122101
123102 def _populate_index_lookups (self ) -> None :
124- if len (self ._mesh_names ) == 0 :
103+ if len (self ._names ) == 0 :
125104 raise ValueError ("No meshes loaded." )
126105 if len (self ._per_mesh_vertex_count ) == 0 :
127106 raise ValueError ("No vertex counts populated." )
@@ -131,16 +110,16 @@ def _populate_index_lookups(self) -> None:
131110 # crucial: self._per_mesh_vertex_count sorted same as self._vertices! Same for faces.
132111 running_vertex_index = 0
133112 running_face_index = 0
134- for mesh_name in self ._mesh_names :
113+ for name in self ._names :
135114 # vertex index lookup
136- self ._lower_vertex_index_lookup [mesh_name ] = running_vertex_index
137- running_vertex_index += self ._per_mesh_vertex_count [mesh_name ]
138- self ._upper_vertex_index_lookup [mesh_name ] = running_vertex_index
115+ self ._lower_vertex_index_lookup [name ] = running_vertex_index
116+ running_vertex_index += self ._per_mesh_vertex_count [name ]
117+ self ._upper_vertex_index_lookup [name ] = running_vertex_index
139118
140119 # face index lookup
141- self ._lower_face_index_lookup [mesh_name ] = running_face_index
142- running_face_index += self ._per_mesh_face_count [mesh_name ]
143- self ._upper_face_index_lookup [mesh_name ] = running_face_index
120+ self ._lower_face_index_lookup [name ] = running_face_index
121+ running_face_index += self ._per_mesh_face_count [name ]
122+ self ._upper_face_index_lookup [name ] = running_face_index
144123
145124 @property
146125 def vertices (self ) -> torch .FloatTensor :
@@ -171,8 +150,8 @@ def upper_face_index_lookup(self) -> Dict[str, int]:
171150 return self ._upper_face_index_lookup
172151
173152 @property
174- def mesh_names (self ) -> List [str ]:
175- return self ._mesh_names
153+ def names (self ) -> List [str ]:
154+ return self ._names
176155
177156 @property
178157 def device (self ) -> torch .device :
0 commit comments