Skip to content

Commit 916bff9

Browse files
committed
ENH: Add pointset data structures [BIAP9]
1 parent 5f37398 commit 916bff9

File tree

2 files changed

+294
-0
lines changed

2 files changed

+294
-0
lines changed

nibabel/pointset.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import operator as op
2+
from functools import reduce
3+
4+
import numpy as np
5+
6+
from nibabel.affines import apply_affine
7+
8+
9+
class Pointset:
10+
def __init__(self, coords):
11+
self._coords = coords
12+
13+
@property
14+
def n_coords(self):
15+
"""Number of coordinates
16+
17+
Subclasses should override with more efficient implementations.
18+
"""
19+
return self.get_coords().shape[0]
20+
21+
def get_coords(self, name=None):
22+
"""Nx3 array of coordinates in RAS+ space"""
23+
return self._coords
24+
25+
26+
class TriangularMesh(Pointset):
27+
def __init__(self, mesh):
28+
if isinstance(mesh, tuple) and len(mesh) == 2:
29+
coords, self._triangles = mesh
30+
elif hasattr(mesh, 'coords') and hasattr(mesh, 'triangles'):
31+
coords = mesh.coords
32+
self._triangles = mesh.triangles
33+
elif hasattr(mesh, 'get_mesh'):
34+
coords, self._triangles = mesh.get_mesh()
35+
else:
36+
raise ValueError('Cannot interpret input as triangular mesh')
37+
super().__init__(coords)
38+
39+
@property
40+
def n_triangles(self):
41+
"""Number of faces
42+
43+
Subclasses should override with more efficient implementations.
44+
"""
45+
return self._triangles.shape[0]
46+
47+
def get_triangles(self):
48+
"""Mx3 array of indices into coordinate table"""
49+
return self._triangles
50+
51+
def get_mesh(self, name=None):
52+
return self.get_coords(name=name), self.get_triangles()
53+
54+
def get_names(self):
55+
"""List of surface names that can be passed to
56+
``get_{coords,triangles,mesh}``
57+
"""
58+
raise NotImplementedError
59+
60+
## This method is called for by the BIAP, but it now seems simpler to wait to
61+
## provide it until there are any proposed implementations
62+
# def decimate(self, *, n_coords=None, ratio=None):
63+
# """ Return a TriangularMesh with a smaller number of vertices that
64+
# preserves the geometry of the original """
65+
# # To be overridden when a format provides optimization opportunities
66+
# raise NotImplementedError
67+
68+
69+
class TriMeshFamily(TriangularMesh):
70+
def __init__(self, mapping, default=None):
71+
self._triangles = None
72+
self._coords = {}
73+
for name, mesh in dict(mapping).items():
74+
coords, triangles = TriangularMesh(mesh).get_mesh()
75+
if self._triangles is None:
76+
self._triangles = triangles
77+
self._coords[name] = coords
78+
79+
if default is None:
80+
default = next(iter(self._coords))
81+
self._default = default
82+
83+
def get_names(self):
84+
return list(self._coords)
85+
86+
def get_coords(self, name=None):
87+
if name is None:
88+
name = self._default
89+
return self._coords[name]
90+
91+
92+
class NdGrid(Pointset):
93+
"""
94+
Attributes
95+
----------
96+
shape : 3-tuple
97+
number of coordinates in each dimension of grid
98+
"""
99+
100+
def __init__(self, shape, affines):
101+
self.shape = tuple(shape)
102+
try:
103+
self._affines = dict(affines)
104+
except (TypeError, ValueError):
105+
self._affines = {'world': np.array(affines)}
106+
if 'voxels' not in self._affines:
107+
self._affines['voxels'] = np.eye(4, dtype=np.uint8)
108+
109+
def get_affine(self, name=None):
110+
"""4x4 array"""
111+
if name is None:
112+
name = next(iter(self._affines))
113+
return self._affines[name]
114+
115+
def get_coords(self, name=None):
116+
if name is None:
117+
name = next(iter(self._affines))
118+
aff = self.get_affine(name)
119+
dt = np.result_type(*(np.min_scalar_type(dim) for dim in self.shape))
120+
# This is pretty wasteful; we almost certainly want instead an
121+
# object that will retrieve a coordinate when indexed, but where
122+
# np.array(obj) returns this
123+
ijk_coords = np.array(list(np.ndindex(self.shape)), dtype=dt)
124+
return apply_affine(aff, ijk_coords)
125+
126+
@property
127+
def n_coords(self):
128+
return reduce(op.mul, self.shape)

nibabel/tests/test_pointset.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from pathlib import Path
2+
from unittest import skipUnless
3+
4+
import numpy as np
5+
6+
from nibabel import pointset as ps
7+
from nibabel.arrayproxy import ArrayProxy
8+
from nibabel.onetime import auto_attr
9+
from nibabel.optpkg import optional_package
10+
from nibabel.tests.nibabel_data import get_nibabel_data
11+
12+
h5, has_h5py, _ = optional_package('h5py')
13+
14+
FS_DATA = Path(get_nibabel_data()) / 'nitest-freesurfer'
15+
16+
17+
class H5ArrayProxy:
18+
def __init__(self, file_like, dataset_name):
19+
self.file_like = file_like
20+
self.dataset_name = dataset_name
21+
with h5.File(file_like, 'r') as h5f:
22+
arr = h5f[dataset_name]
23+
self._shape = arr.shape
24+
self._dtype = arr.dtype
25+
26+
@property
27+
def is_proxy(self):
28+
return True
29+
30+
@property
31+
def shape(self):
32+
return self._shape
33+
34+
@property
35+
def ndim(self):
36+
return len(self.shape)
37+
38+
@property
39+
def dtype(self):
40+
return self._dtype
41+
42+
def __array__(self, dtype=None):
43+
with h5.File(self.file_like, 'r') as h5f:
44+
return np.asanyarray(h5f[self.dataset_name], dtype)
45+
46+
def __getitem__(self, slicer):
47+
with h5.File(self.file_like, 'r') as h5f:
48+
return h5f[self.dataset_name][slicer]
49+
50+
51+
class H5Geometry(ps.TriMeshFamily):
52+
"""Simple Geometry file structure that combines a single topology
53+
with one or more coordinate sets
54+
"""
55+
56+
@classmethod
57+
def from_filename(klass, pathlike):
58+
meshes = {}
59+
with h5.File(pathlike, 'r') as h5f:
60+
triangles = H5ArrayProxy(pathlike, '/topology')
61+
for name in h5f['coordinates']:
62+
meshes[name] = (H5ArrayProxy(pathlike, f'/coordinates/{name}'), triangles)
63+
return klass(meshes)
64+
65+
def to_filename(self, pathlike):
66+
with h5.File(pathlike, 'w') as h5f:
67+
h5f.create_dataset('/topology', data=self.get_triangles())
68+
for name, coord in self._coords.items():
69+
h5f.create_dataset(f'/coordinates/{name}', data=coord)
70+
71+
72+
class FSGeometryProxy:
73+
def __init__(self, pathlike):
74+
self._file_like = str(Path(pathlike))
75+
self._offset = None
76+
self._vnum = None
77+
self._fnum = None
78+
79+
def _peek(self):
80+
from nibabel.freesurfer.io import _fread3
81+
82+
with open(self._file_like, 'rb') as fobj:
83+
magic = _fread3(fobj)
84+
if magic != 16777214:
85+
raise NotImplementedError('Triangle files only!')
86+
fobj.readline()
87+
fobj.readline()
88+
self._vnum = np.fromfile(fobj, '>i4', 1)[0]
89+
self._fnum = np.fromfile(fobj, '>i4', 1)[0]
90+
self._offset = fobj.tell()
91+
92+
@property
93+
def vnum(self):
94+
if self._vnum is None:
95+
self._peek()
96+
return self._vnum
97+
98+
@property
99+
def fnum(self):
100+
if self._fnum is None:
101+
self._peek()
102+
return self._fnum
103+
104+
@property
105+
def offset(self):
106+
if self._offset is None:
107+
self._peek()
108+
return self._offset
109+
110+
@auto_attr
111+
def coords(self):
112+
ap = ArrayProxy(self._file_like, ((self.vnum, 3), '>f4', self.offset))
113+
ap.order = 'C'
114+
return ap
115+
116+
@auto_attr
117+
def triangles(self):
118+
offset = self.offset + 12 * self.vnum
119+
ap = ArrayProxy(self._file_like, ((self.fnum, 3), '>i4', offset))
120+
ap.order = 'C'
121+
return ap
122+
123+
124+
class FreeSurferHemisphere(ps.TriMeshFamily):
125+
@classmethod
126+
def from_filename(klass, pathlike):
127+
path = Path(pathlike)
128+
hemi, default = path.name.split('.')
129+
mesh_names = (
130+
'orig',
131+
'white',
132+
'smoothwm',
133+
'pial',
134+
'inflated',
135+
'sphere',
136+
'midthickness',
137+
'graymid',
138+
) # Often created
139+
if default not in mesh_names:
140+
mesh_names.append(default)
141+
meshes = {}
142+
for mesh in mesh_names:
143+
fpath = path.parent / f'{hemi}.{mesh}'
144+
if fpath.exists():
145+
meshes[mesh] = FSGeometryProxy(fpath)
146+
hemi = klass(meshes)
147+
hemi._default = default
148+
return hemi
149+
150+
151+
def test_FreeSurferHemisphere():
152+
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
153+
assert lh.n_coords == 163842
154+
assert lh.n_triangles == 327680
155+
156+
157+
@skipUnless(has_h5py, reason='Test requires h5py')
158+
def test_make_H5Geometry(tmp_path):
159+
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
160+
h5geo = H5Geometry({name: lh.get_mesh(name) for name in ('white', 'pial')})
161+
h5geo.to_filename(tmp_path / 'geometry.h5')
162+
163+
rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5')
164+
assert set(h5geo._coords) == set(rt_h5geo._coords)
165+
assert np.array_equal(lh.get_coords('white'), rt_h5geo.get_coords('white'))
166+
assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())

0 commit comments

Comments
 (0)