Skip to content

Commit dcc8fa1

Browse files
committed
RF: Add TriMeshFamily to formalize 1-mesh-to-many-coords structure
1 parent ba2af8f commit dcc8fa1

File tree

2 files changed

+49
-57
lines changed

2 files changed

+49
-57
lines changed

nibabel/pointset.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,49 @@
77

88

99
class Pointset:
10+
def __init__(self, coords):
11+
self._coords = coords
12+
1013
@property
1114
def n_coords(self):
1215
""" Number of coordinates
1316
1417
Subclasses should override with more efficient implementations.
1518
"""
16-
return len(self.get_coords())
19+
return self.get_coords().shape[0]
1720

1821
def get_coords(self, name=None):
1922
""" Nx3 array of coordinates in RAS+ space """
20-
raise NotImplementedError
23+
return self._coords
2124

2225

2326
class TriangularMesh(Pointset):
27+
def __init__(self, mesh):
28+
if isinstance(mesh, tuple) and len(mesh) == 2:
29+
coords, self._triangles = mesh
30+
elif hasattr(mesh, 'coords') and hasattr(mesh, 'triangles'):
31+
coords = mesh.coords
32+
self._triangles = mesh.triangles
33+
elif hasattr(mesh, 'get_mesh'):
34+
coords, self._triangles = mesh.get_mesh()
35+
else:
36+
raise ValueError("Cannot interpret input as triangular mesh")
37+
super().__init__(coords)
38+
2439
@property
2540
def n_triangles(self):
2641
""" Number of faces
2742
2843
Subclasses should override with more efficient implementations.
2944
"""
30-
return len(self.get_triangles())
45+
return self._triangles.shape[0]
3146

32-
def get_triangles(self, name=None):
47+
def get_triangles(self):
3348
""" Mx3 array of indices into coordinate table """
34-
raise NotImplementedError
49+
return self._triangles
3550

3651
def get_mesh(self, name=None):
37-
return self.get_coords(name=name), self.get_triangles(name=name)
52+
return self.get_coords(name=name), self.get_triangles()
3853

3954
def get_names(self):
4055
""" List of surface names that can be passed to
@@ -51,6 +66,29 @@ def get_names(self):
5166
# raise NotImplementedError
5267

5368

69+
class TriMeshFamily(TriangularMesh):
70+
def __init__(self, mapping, default=None):
71+
self._triangles = None
72+
self._coords = {}
73+
for name, mesh in dict(mapping).items():
74+
coords, triangles = TriangularMesh(mesh).get_mesh()
75+
if self._triangles is None:
76+
self._triangles = triangles
77+
self._coords[name] = coords
78+
79+
if default is None:
80+
default = next(iter(self._coords))
81+
self._default = default
82+
83+
def get_names(self):
84+
return list(self._coords)
85+
86+
def get_coords(self, name=None):
87+
if name is None:
88+
name = self._default
89+
return self._coords[name]
90+
91+
5492
class NdGrid(Pointset):
5593
"""
5694
Attributes

nibabel/tests/test_pointset.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,10 @@ def __getitem__(self, slicer):
4747
return h5f[self.dataset_name][slicer]
4848

4949

50-
class H5Geometry(ps.TriangularMesh):
50+
class H5Geometry(ps.TriMeshFamily):
5151
"""Simple Geometry file structure that combines a single topology
5252
with one or more coordinate sets
5353
"""
54-
def __init__(self, meshes):
55-
self._meshes = meshes
56-
5754
@classmethod
5855
def from_filename(klass, pathlike):
5956
meshes = {}
@@ -64,33 +61,11 @@ def from_filename(klass, pathlike):
6461
return klass(meshes)
6562

6663
def to_filename(self, pathlike):
67-
topology = None
68-
coordinates = {}
69-
for name, mesh in self._meshes.items():
70-
coords, faces = mesh
71-
if topology is None:
72-
topology = faces
73-
elif not np.array_equal(faces, topology):
74-
raise ValueError("Inconsistent topology")
75-
coordinates[name] = coords
76-
7764
with h5.File(pathlike, "w") as h5f:
78-
h5f.create_dataset("/topology", data=topology)
79-
for name, coord in coordinates.items():
65+
h5f.create_dataset("/topology", data=self.get_triangles())
66+
for name, coord in self._coords.items():
8067
h5f.create_dataset(f"/coordinates/{name}", data=coord)
8168

82-
def get_coords(self, name=None):
83-
if name is None:
84-
name = next(iter(self._meshes))
85-
coords, _ = self._meshes[name]
86-
return coords
87-
88-
def get_triangles(self, name=None):
89-
if name is None:
90-
name = next(iter(self._meshes))
91-
_, triangles = self._meshes[name]
92-
return triangles
93-
9469

9570
class FSGeometryProxy:
9671
def __init__(self, pathlike):
@@ -143,10 +118,7 @@ def triangles(self):
143118
return ap
144119

145120

146-
class FreeSurferHemisphere(ps.TriangularMesh):
147-
def __init__(self, meshes):
148-
self._meshes = meshes
149-
121+
class FreeSurferHemisphere(ps.TriMeshFamily):
150122
@classmethod
151123
def from_filename(klass, pathlike):
152124
path = Path(pathlike)
@@ -165,24 +137,6 @@ def from_filename(klass, pathlike):
165137
hemi._default = default
166138
return hemi
167139

168-
def get_coords(self, name=None):
169-
if name is None:
170-
name = self._default
171-
return self._meshes[name].coords
172-
173-
def get_triangles(self, name=None):
174-
if name is None:
175-
name = self._default
176-
return self._meshes[name].triangles
177-
178-
@property
179-
def n_coords(self):
180-
return self._meshes[self._default].vnum
181-
182-
@property
183-
def n_triangles(self):
184-
return self._meshes[self._default].fnum
185-
186140

187141
def test_FreeSurferHemisphere():
188142
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
@@ -197,6 +151,6 @@ def test_make_H5Geometry(tmp_path):
197151
h5geo.to_filename(tmp_path / "geometry.h5")
198152

199153
rt_h5geo = H5Geometry.from_filename(tmp_path / "geometry.h5")
200-
assert set(h5geo._meshes) == set(rt_h5geo._meshes)
154+
assert set(h5geo._coords) == set(rt_h5geo._coords)
201155
assert np.array_equal(lh.get_coords('white'), rt_h5geo.get_coords('white'))
202156
assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())

0 commit comments

Comments
 (0)