Skip to content

Commit c526d2e

Browse files
committed
RF: Add TriMeshFamily to formalize 1-mesh-to-many-coords structure
1 parent 0d0a6db commit c526d2e

File tree

2 files changed

+49
-57
lines changed

2 files changed

+49
-57
lines changed

nibabel/pointset.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,49 @@
77

88

99
class Pointset:
10+
def __init__(self, coords):
11+
self._coords = coords
12+
1013
@property
1114
def n_coords(self):
1215
"""Number of coordinates
1316
1417
Subclasses should override with more efficient implementations.
1518
"""
16-
return len(self.get_coords())
19+
return self.get_coords().shape[0]
1720

1821
def get_coords(self, name=None):
1922
"""Nx3 array of coordinates in RAS+ space"""
20-
raise NotImplementedError
23+
return self._coords
2124

2225

2326
class TriangularMesh(Pointset):
27+
def __init__(self, mesh):
28+
if isinstance(mesh, tuple) and len(mesh) == 2:
29+
coords, self._triangles = mesh
30+
elif hasattr(mesh, 'coords') and hasattr(mesh, 'triangles'):
31+
coords = mesh.coords
32+
self._triangles = mesh.triangles
33+
elif hasattr(mesh, 'get_mesh'):
34+
coords, self._triangles = mesh.get_mesh()
35+
else:
36+
raise ValueError('Cannot interpret input as triangular mesh')
37+
super().__init__(coords)
38+
2439
@property
2540
def n_triangles(self):
2641
"""Number of faces
2742
2843
Subclasses should override with more efficient implementations.
2944
"""
30-
return len(self.get_triangles())
45+
return self._triangles.shape[0]
3146

32-
def get_triangles(self, name=None):
47+
def get_triangles(self):
3348
"""Mx3 array of indices into coordinate table"""
34-
raise NotImplementedError
49+
return self._triangles
3550

3651
def get_mesh(self, name=None):
37-
return self.get_coords(name=name), self.get_triangles(name=name)
52+
return self.get_coords(name=name), self.get_triangles()
3853

3954
def get_names(self):
4055
"""List of surface names that can be passed to
@@ -51,6 +66,29 @@ def get_names(self):
5166
# raise NotImplementedError
5267

5368

69+
class TriMeshFamily(TriangularMesh):
70+
def __init__(self, mapping, default=None):
71+
self._triangles = None
72+
self._coords = {}
73+
for name, mesh in dict(mapping).items():
74+
coords, triangles = TriangularMesh(mesh).get_mesh()
75+
if self._triangles is None:
76+
self._triangles = triangles
77+
self._coords[name] = coords
78+
79+
if default is None:
80+
default = next(iter(self._coords))
81+
self._default = default
82+
83+
def get_names(self):
84+
return list(self._coords)
85+
86+
def get_coords(self, name=None):
87+
if name is None:
88+
name = self._default
89+
return self._coords[name]
90+
91+
5492
class NdGrid(Pointset):
5593
"""
5694
Attributes

nibabel/tests/test_pointset.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,11 @@ def __getitem__(self, slicer):
4747
return h5f[self.dataset_name][slicer]
4848

4949

50-
class H5Geometry(ps.TriangularMesh):
50+
class H5Geometry(ps.TriMeshFamily):
5151
"""Simple Geometry file structure that combines a single topology
5252
with one or more coordinate sets
5353
"""
5454

55-
def __init__(self, meshes):
56-
self._meshes = meshes
57-
5855
@classmethod
5956
def from_filename(klass, pathlike):
6057
meshes = {}
@@ -65,33 +62,11 @@ def from_filename(klass, pathlike):
6562
return klass(meshes)
6663

6764
def to_filename(self, pathlike):
68-
topology = None
69-
coordinates = {}
70-
for name, mesh in self._meshes.items():
71-
coords, faces = mesh
72-
if topology is None:
73-
topology = faces
74-
elif not np.array_equal(faces, topology):
75-
raise ValueError('Inconsistent topology')
76-
coordinates[name] = coords
77-
7865
with h5.File(pathlike, 'w') as h5f:
79-
h5f.create_dataset('/topology', data=topology)
80-
for name, coord in coordinates.items():
66+
h5f.create_dataset('/topology', data=self.get_triangles())
67+
for name, coord in self._coords.items():
8168
h5f.create_dataset(f'/coordinates/{name}', data=coord)
8269

83-
def get_coords(self, name=None):
84-
if name is None:
85-
name = next(iter(self._meshes))
86-
coords, _ = self._meshes[name]
87-
return coords
88-
89-
def get_triangles(self, name=None):
90-
if name is None:
91-
name = next(iter(self._meshes))
92-
_, triangles = self._meshes[name]
93-
return triangles
94-
9570

9671
class FSGeometryProxy:
9772
def __init__(self, pathlike):
@@ -145,10 +120,7 @@ def triangles(self):
145120
return ap
146121

147122

148-
class FreeSurferHemisphere(ps.TriangularMesh):
149-
def __init__(self, meshes):
150-
self._meshes = meshes
151-
123+
class FreeSurferHemisphere(ps.TriMeshFamily):
152124
@classmethod
153125
def from_filename(klass, pathlike):
154126
path = Path(pathlike)
@@ -174,24 +146,6 @@ def from_filename(klass, pathlike):
174146
hemi._default = default
175147
return hemi
176148

177-
def get_coords(self, name=None):
178-
if name is None:
179-
name = self._default
180-
return self._meshes[name].coords
181-
182-
def get_triangles(self, name=None):
183-
if name is None:
184-
name = self._default
185-
return self._meshes[name].triangles
186-
187-
@property
188-
def n_coords(self):
189-
return self._meshes[self._default].vnum
190-
191-
@property
192-
def n_triangles(self):
193-
return self._meshes[self._default].fnum
194-
195149

196150
def test_FreeSurferHemisphere():
197151
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
@@ -206,6 +160,6 @@ def test_make_H5Geometry(tmp_path):
206160
h5geo.to_filename(tmp_path / 'geometry.h5')
207161

208162
rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5')
209-
assert set(h5geo._meshes) == set(rt_h5geo._meshes)
163+
assert set(h5geo._coords) == set(rt_h5geo._coords)
210164
assert np.array_equal(lh.get_coords('white'), rt_h5geo.get_coords('white'))
211165
assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())

0 commit comments

Comments
 (0)