Skip to content

Commit b70a4d9

Browse files
committed
Merge branches 'enh/copyarrayproxy', 'enh/xml-kwargs' and 'enh/triangular_mesh' into biap9-rebase
3 parents 81b1033 + cea2f6c + 368c145 commit b70a4d9

File tree

4 files changed

+348
-15
lines changed

4 files changed

+348
-15
lines changed

nibabel/gifti/gifti.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ def _to_xml_element(self):
852852
GIFTI.append(dar._to_xml_element())
853853
return GIFTI
854854

855-
def to_xml(self, enc='utf-8', *, mode='strict') -> bytes:
855+
def to_xml(self, enc='utf-8', *, mode='strict', **kwargs) -> bytes:
856856
"""Return XML corresponding to image content"""
857857
if mode == 'strict':
858858
if any(arr.datatype not in GIFTI_DTYPES for arr in self.darrays):
@@ -882,7 +882,7 @@ def to_xml(self, enc='utf-8', *, mode='strict') -> bytes:
882882
header = b"""<?xml version="1.0" encoding="UTF-8"?>
883883
<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
884884
"""
885-
return header + super().to_xml(enc)
885+
return header + super().to_xml(enc, **kwargs)
886886

887887
# Avoid the indirection of going through to_file_map
888888
def to_bytes(self, enc='utf-8', *, mode='strict'):

nibabel/pointset.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def __array__(self, dtype: _DType, /) -> np.ndarray[ty.Any, _DType]:
4848
... # pragma: no cover
4949

5050

51-
@dataclass
51+
class HasMeshAttrs(ty.Protocol):
52+
coordinates: CoordinateArray
53+
triangles: CoordinateArray
54+
55+
56+
@dataclass(init=False)
5257
class Pointset:
5358
"""A collection of points described by coordinates.
5459
@@ -65,7 +70,7 @@ class Pointset:
6570

6671
coordinates: CoordinateArray
6772
affine: np.ndarray
68-
homogeneous: bool = False
73+
homogeneous: bool
6974

7075
# Force use of __rmatmul__ with numpy arrays
7176
__array_priority__ = 99
@@ -144,6 +149,82 @@ def get_coords(self, *, as_homogeneous: bool = False):
144149
return coords
145150

146151

152+
@dataclass(init=False)
153+
class TriangularMesh(Pointset):
154+
triangles: CoordinateArray
155+
156+
def __init__(
157+
self,
158+
coordinates: CoordinateArray,
159+
triangles: CoordinateArray,
160+
affine: np.ndarray | None = None,
161+
homogeneous: bool = False,
162+
):
163+
super().__init__(coordinates, affine=affine, homogeneous=homogeneous)
164+
self.triangles = triangles
165+
166+
@classmethod
167+
def from_tuple(
168+
cls,
169+
mesh: tuple[CoordinateArray, CoordinateArray],
170+
affine: np.ndarray | None = None,
171+
homogeneous: bool = False,
172+
**kwargs,
173+
) -> Self:
174+
return cls(mesh[0], mesh[1], affine=affine, homogeneous=homogeneous, **kwargs)
175+
176+
@classmethod
177+
def from_object(
178+
cls,
179+
mesh: HasMeshAttrs,
180+
affine: np.ndarray | None = None,
181+
homogeneous: bool = False,
182+
**kwargs,
183+
) -> Self:
184+
return cls(
185+
mesh.coordinates, mesh.triangles, affine=affine, homogeneous=homogeneous, **kwargs
186+
)
187+
188+
@property
189+
def n_triangles(self):
190+
"""Number of faces
191+
192+
Subclasses should override with more efficient implementations.
193+
"""
194+
return self.triangles.shape[0]
195+
196+
def get_triangles(self):
197+
"""Mx3 array of indices into coordinate table"""
198+
return np.asanyarray(self.triangles)
199+
200+
def get_mesh(self, *, as_homogeneous: bool = False):
201+
return self.get_coords(as_homogeneous=as_homogeneous), self.get_triangles()
202+
203+
204+
class CoordinateFamilyMixin(Pointset):
205+
def __init__(self, *args, name='original', **kwargs):
206+
mapping = kwargs.pop('mapping', {})
207+
super().__init__(*args, **kwargs)
208+
self._coords = {name: self.coordinates, **mapping}
209+
210+
def get_names(self):
211+
"""List of surface names that can be passed to :meth:`with_name`"""
212+
return list(self._coords)
213+
214+
def with_name(self, name: str) -> Self:
215+
new_coords = self._coords[name]
216+
if new_coords is self.coordinates:
217+
return self
218+
# Make a copy, preserving all dataclass fields
219+
new = replace(self, coordinates=new_coords)
220+
# Conserve exact _coords mapping
221+
new._coords = self._coords
222+
return new
223+
224+
def add_coordinates(self, name, coordinates):
225+
self._coords[name] = coordinates
226+
227+
147228
class Grid(Pointset):
148229
r"""A regularly-spaced collection of coordinates
149230

nibabel/tests/test_pointset.py

Lines changed: 246 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import namedtuple
12
from math import prod
23
from pathlib import Path
34
from unittest import skipUnless
@@ -12,7 +13,7 @@
1213
from nibabel.onetime import auto_attr
1314
from nibabel.optpkg import optional_package
1415
from nibabel.spatialimages import SpatialImage
15-
from nibabel.tests.nibabel_data import get_nibabel_data
16+
from nibabel.tests.nibabel_data import get_nibabel_data, needs_nibabel_data
1617

1718
h5, has_h5py, _ = optional_package('h5py')
1819

@@ -182,3 +183,247 @@ def test_to_mask(self):
182183
],
183184
)
184185
assert np.array_equal(mask_img.affine, np.eye(4))
186+
187+
188+
class TestTriangularMeshes:
189+
def test_api(self):
190+
# Tetrahedron
191+
coords = np.array(
192+
[
193+
[0.0, 0.0, 0.0],
194+
[0.0, 0.0, 1.0],
195+
[0.0, 1.0, 0.0],
196+
[1.0, 0.0, 0.0],
197+
]
198+
)
199+
triangles = np.array(
200+
[
201+
[0, 2, 1],
202+
[0, 3, 2],
203+
[0, 1, 3],
204+
[1, 2, 3],
205+
]
206+
)
207+
208+
mesh = namedtuple('mesh', ('coordinates', 'triangles'))(coords, triangles)
209+
210+
tm1 = ps.TriangularMesh(coords, triangles)
211+
tm2 = ps.TriangularMesh.from_tuple(mesh)
212+
tm3 = ps.TriangularMesh.from_object(mesh)
213+
214+
assert np.allclose(tm1.affine, np.eye(4))
215+
assert np.allclose(tm2.affine, np.eye(4))
216+
assert np.allclose(tm3.affine, np.eye(4))
217+
218+
assert tm1.homogeneous is False
219+
assert tm2.homogeneous is False
220+
assert tm3.homogeneous is False
221+
222+
assert (tm1.n_coords, tm1.dim) == (4, 3)
223+
assert (tm2.n_coords, tm2.dim) == (4, 3)
224+
assert (tm3.n_coords, tm3.dim) == (4, 3)
225+
226+
assert tm1.n_triangles == 4
227+
assert tm2.n_triangles == 4
228+
assert tm3.n_triangles == 4
229+
230+
out_coords, out_tris = tm1.get_mesh()
231+
# Currently these are the exact arrays, but I don't think we should
232+
# bake that assumption into the tests
233+
assert np.allclose(out_coords, coords)
234+
assert np.allclose(out_tris, triangles)
235+
236+
237+
class TestCoordinateFamilyMixin(TestPointsets):
238+
def test_names(self):
239+
coords = np.array(
240+
[
241+
[0.0, 0.0, 0.0],
242+
[0.0, 0.0, 1.0],
243+
[0.0, 1.0, 0.0],
244+
[1.0, 0.0, 0.0],
245+
]
246+
)
247+
cfm = ps.CoordinateFamilyMixin(coords)
248+
249+
assert cfm.get_names() == ['original']
250+
assert np.allclose(cfm.with_name('original').coordinates, coords)
251+
252+
cfm.add_coordinates('shifted', coords + 1)
253+
assert set(cfm.get_names()) == {'original', 'shifted'}
254+
shifted = cfm.with_name('shifted')
255+
assert np.allclose(shifted.coordinates, coords + 1)
256+
assert set(shifted.get_names()) == {'original', 'shifted'}
257+
original = shifted.with_name('original')
258+
assert np.allclose(original.coordinates, coords)
259+
260+
# Avoid duplicating objects
261+
assert original.with_name('original') is original
262+
# But don't try too hard
263+
assert original.with_name('original') is not cfm
264+
265+
# with_name() preserves the exact coordinate mapping of the source object.
266+
# Modifications of one are immediately available to all others.
267+
# This is currently an implementation detail, and the expectation is that
268+
# a family will be created once and then queried, but this behavior could
269+
# potentially become confusing or relied upon.
270+
# Change with care.
271+
shifted.add_coordinates('shifted-again', coords + 2)
272+
shift2 = shifted.with_name('shifted-again')
273+
shift3 = cfm.with_name('shifted-again')
274+
275+
276+
class H5ArrayProxy:
277+
def __init__(self, file_like, dataset_name):
278+
self.file_like = file_like
279+
self.dataset_name = dataset_name
280+
with h5.File(file_like, 'r') as h5f:
281+
arr = h5f[dataset_name]
282+
self._shape = arr.shape
283+
self._dtype = arr.dtype
284+
285+
@property
286+
def is_proxy(self):
287+
return True
288+
289+
@property
290+
def shape(self):
291+
return self._shape
292+
293+
@property
294+
def ndim(self):
295+
return len(self.shape)
296+
297+
@property
298+
def dtype(self):
299+
return self._dtype
300+
301+
def __array__(self, dtype=None):
302+
with h5.File(self.file_like, 'r') as h5f:
303+
return np.asanyarray(h5f[self.dataset_name], dtype)
304+
305+
def __getitem__(self, slicer):
306+
with h5.File(self.file_like, 'r') as h5f:
307+
return h5f[self.dataset_name][slicer]
308+
309+
310+
class H5Geometry(ps.CoordinateFamilyMixin, ps.TriangularMesh):
311+
"""Simple Geometry file structure that combines a single topology
312+
with one or more coordinate sets
313+
"""
314+
315+
@classmethod
316+
def from_filename(klass, pathlike):
317+
coords = {}
318+
with h5.File(pathlike, 'r') as h5f:
319+
triangles = H5ArrayProxy(pathlike, '/topology')
320+
for name in h5f['coordinates']:
321+
coords[name] = H5ArrayProxy(pathlike, f'/coordinates/{name}')
322+
self = klass(next(iter(coords.values())), triangles, mapping=coords)
323+
return self
324+
325+
def to_filename(self, pathlike):
326+
with h5.File(pathlike, 'w') as h5f:
327+
h5f.create_dataset('/topology', data=self.get_triangles())
328+
for name, coord in self._coords.items():
329+
h5f.create_dataset(f'/coordinates/{name}', data=coord)
330+
331+
332+
class FSGeometryProxy:
333+
def __init__(self, pathlike):
334+
self._file_like = str(Path(pathlike))
335+
self._offset = None
336+
self._vnum = None
337+
self._fnum = None
338+
339+
def _peek(self):
340+
from nibabel.freesurfer.io import _fread3
341+
342+
with open(self._file_like, 'rb') as fobj:
343+
magic = _fread3(fobj)
344+
if magic != 16777214:
345+
raise NotImplementedError('Triangle files only!')
346+
fobj.readline()
347+
fobj.readline()
348+
self._vnum = np.fromfile(fobj, '>i4', 1)[0]
349+
self._fnum = np.fromfile(fobj, '>i4', 1)[0]
350+
self._offset = fobj.tell()
351+
352+
@property
353+
def vnum(self):
354+
if self._vnum is None:
355+
self._peek()
356+
return self._vnum
357+
358+
@property
359+
def fnum(self):
360+
if self._fnum is None:
361+
self._peek()
362+
return self._fnum
363+
364+
@property
365+
def offset(self):
366+
if self._offset is None:
367+
self._peek()
368+
return self._offset
369+
370+
@auto_attr
371+
def coordinates(self):
372+
return ArrayProxy(self._file_like, ((self.vnum, 3), '>f4', self.offset), order='C')
373+
374+
@auto_attr
375+
def triangles(self):
376+
return ArrayProxy(
377+
self._file_like,
378+
((self.fnum, 3), '>i4', self.offset + 12 * self.vnum),
379+
order='C',
380+
)
381+
382+
383+
class FreeSurferHemisphere(ps.CoordinateFamilyMixin, ps.TriangularMesh):
384+
@classmethod
385+
def from_filename(klass, pathlike):
386+
path = Path(pathlike)
387+
hemi, default = path.name.split('.')
388+
self = klass.from_object(FSGeometryProxy(path), name=default)
389+
mesh_names = (
390+
'orig',
391+
'white',
392+
'smoothwm',
393+
'pial',
394+
'inflated',
395+
'sphere',
396+
'midthickness',
397+
'graymid',
398+
) # Often created
399+
400+
for mesh in mesh_names:
401+
if mesh != default:
402+
fpath = path.parent / f'{hemi}.{mesh}'
403+
if fpath.exists():
404+
self.add_coordinates(mesh, FSGeometryProxy(fpath).coordinates)
405+
return self
406+
407+
408+
@needs_nibabel_data('nitest-freesurfer')
409+
def test_FreeSurferHemisphere():
410+
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
411+
assert lh.n_coords == 163842
412+
assert lh.n_triangles == 327680
413+
414+
415+
@skipUnless(has_h5py, reason='Test requires h5py')
416+
@needs_nibabel_data('nitest-freesurfer')
417+
def test_make_H5Geometry(tmp_path):
418+
lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white')
419+
h5geo = H5Geometry.from_object(lh)
420+
for name in ('white', 'pial'):
421+
h5geo.add_coordinates(name, lh.with_name(name).coordinates)
422+
h5geo.to_filename(tmp_path / 'geometry.h5')
423+
424+
rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5')
425+
assert set(h5geo._coords) == set(rt_h5geo._coords)
426+
assert np.array_equal(
427+
lh.with_name('white').get_coords(), rt_h5geo.with_name('white').get_coords()
428+
)
429+
assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())

0 commit comments

Comments
 (0)