Skip to content

Commit 246469c

Browse files
committed
RF: Add TriMeshFamily to formalize 1-mesh-to-many-coords structure
1 parent 7d13e3b commit 246469c

File tree

2 files changed

+44
-55
lines changed

2 files changed

+44
-55
lines changed

nibabel/pointset.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,32 @@ def get_coords(self, *, as_homogeneous: bool = False):
145145

146146

147147
class TriangularMesh(Pointset):
148+
def __init__(self, mesh):
149+
if isinstance(mesh, tuple) and len(mesh) == 2:
150+
coords, self._triangles = mesh
151+
elif hasattr(mesh, 'coords') and hasattr(mesh, 'triangles'):
152+
coords = mesh.coords
153+
self._triangles = mesh.triangles
154+
elif hasattr(mesh, 'get_mesh'):
155+
coords, self._triangles = mesh.get_mesh()
156+
else:
157+
raise ValueError('Cannot interpret input as triangular mesh')
158+
super().__init__(coords)
159+
148160
@property
149161
def n_triangles(self):
150162
"""Number of faces
151163
152164
Subclasses should override with more efficient implementations.
153165
"""
154-
return len(self.get_triangles())
166+
return self._triangles.shape[0]
155167

156-
def get_triangles(self, name=None):
168+
def get_triangles(self):
157169
"""Mx3 array of indices into coordinate table"""
158-
raise NotImplementedError
170+
return self._triangles
159171

160172
def get_mesh(self, name=None):
161-
return self.get_coords(name=name), self.get_triangles(name=name)
173+
return self.get_coords(name=name), self.get_triangles()
162174

163175
def get_names(self):
164176
"""List of surface names that can be passed to
@@ -175,6 +187,29 @@ def get_names(self):
175187
# raise NotImplementedError
176188

177189

190+
class TriMeshFamily(TriangularMesh):
191+
def __init__(self, mapping, default=None):
192+
self._triangles = None
193+
self._coords = {}
194+
for name, mesh in dict(mapping).items():
195+
coords, triangles = TriangularMesh(mesh).get_mesh()
196+
if self._triangles is None:
197+
self._triangles = triangles
198+
self._coords[name] = coords
199+
200+
if default is None:
201+
default = next(iter(self._coords))
202+
self._default = default
203+
204+
def get_names(self):
205+
return list(self._coords)
206+
207+
def get_coords(self, name=None):
208+
if name is None:
209+
name = self._default
210+
return self._coords[name]
211+
212+
178213
class Grid(Pointset):
179214
r"""A regularly-spaced collection of coordinates
180215

nibabel/tests/test_pointset.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,11 @@ def __getitem__(self, slicer):
218218
return h5f[self.dataset_name][slicer]
219219

220220

221-
class H5Geometry(ps.TriangularMesh):
221+
class H5Geometry(ps.TriMeshFamily):
222222
"""Simple Geometry file structure that combines a single topology
223223
with one or more coordinate sets
224224
"""
225225

226-
def __init__(self, meshes):
227-
self._meshes = meshes
228-
229226
@classmethod
230227
def from_filename(klass, pathlike):
231228
meshes = {}
@@ -236,33 +233,11 @@ def from_filename(klass, pathlike):
236233
return klass(meshes)
237234

238235
def to_filename(self, pathlike):
239-
topology = None
240-
coordinates = {}
241-
for name, mesh in self._meshes.items():
242-
coords, faces = mesh
243-
if topology is None:
244-
topology = faces
245-
elif not np.array_equal(faces, topology):
246-
raise ValueError('Inconsistent topology')
247-
coordinates[name] = coords
248-
249236
with h5.File(pathlike, 'w') as h5f:
250-
h5f.create_dataset('/topology', data=topology)
251-
for name, coord in coordinates.items():
237+
h5f.create_dataset('/topology', data=self.get_triangles())
238+
for name, coord in self._coords.items():
252239
h5f.create_dataset(f'/coordinates/{name}', data=coord)
253240

254-
def get_coords(self, name=None):
255-
if name is None:
256-
name = next(iter(self._meshes))
257-
coords, _ = self._meshes[name]
258-
return coords
259-
260-
def get_triangles(self, name=None):
261-
if name is None:
262-
name = next(iter(self._meshes))
263-
_, triangles = self._meshes[name]
264-
return triangles
265-
266241

267242
class FSGeometryProxy:
268243
def __init__(self, pathlike):
@@ -316,10 +291,7 @@ def triangles(self):
316291
return ap
317292

318293

319-
class FreeSurferHemisphere(ps.TriangularMesh):
320-
def __init__(self, meshes):
321-
self._meshes = meshes
322-
294+
class FreeSurferHemisphere(ps.TriMeshFamily):
323295
@classmethod
324296
def from_filename(klass, pathlike):
325297
path = Path(pathlike)
@@ -345,24 +317,6 @@ def from_filename(klass, pathlike):
345317
hemi._default = default
346318
return hemi
347319

348-
def get_coords(self, name=None):
349-
if name is None:
350-
name = self._default
351-
return self._meshes[name].coords
352-
353-
def get_triangles(self, name=None):
354-
if name is None:
355-
name = self._default
356-
return self._meshes[name].triangles
357-
358-
@property
359-
def n_coords(self):
360-
return self._meshes[self._default].vnum
361-
362-
@property
363-
def n_triangles(self):
364-
return self._meshes[self._default].fnum
365-
366320

367321
def test_FreeSurferHemisphere():
368322
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
@@ -377,6 +331,6 @@ def test_make_H5Geometry(tmp_path):
377331
h5geo.to_filename(tmp_path / 'geometry.h5')
378332

379333
rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5')
380-
assert set(h5geo._meshes) == set(rt_h5geo._meshes)
334+
assert set(h5geo._coords) == set(rt_h5geo._coords)
381335
assert np.array_equal(lh.get_coords('white'), rt_h5geo.get_coords('white'))
382336
assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())

0 commit comments

Comments
 (0)