Skip to content

Commit 0c9d4ee

Browse files
committed
Merge remote-tracking branch 'oesteban/enh/MeshWarp'
* oesteban/enh/MeshWarp: remove unncecessary pass statements describe the input displacement field fix conflicts Added ETSConfig to avoid display errors fixed error in doctest Edited CHANGES, added doctets, fixed errors warping meshes
2 parents f882fd2 + 095fc7c commit 0c9d4ee

File tree

2 files changed

+131
-5
lines changed

2 files changed

+131
-5
lines changed

CHANGES

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ Next release
22
============
33

44
* API: Change how hash values are computed (https://github.com/nipy/nipype/pull/1174)
5+
* ENH: New algorithm: mesh.WarpPoints applies displacements fields to point sets
6+
(https://github.com/nipy/nipype/pull/889).
57
* ENH: New interfaces for MRTrix3 (https://github.com/nipy/nipype/pull/1126)
68
* ENH: New option in afni.3dRefit - zdel, ydel, zdel etc. (https://github.com/nipy/nipype/pull/1079)
79
* FIX: ants.Registration composite transform outputs are no longer returned as lists (https://github.com/nipy/nipype/pull/1183)

nipype/algorithms/mesh.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,140 @@
2525
iflogger = logging.getLogger('interface')
2626

2727

28+
class WarpPointsInputSpec(BaseInterfaceInputSpec):
29+
points = File(exists=True, mandatory=True,
30+
desc=('file containing the point set'))
31+
warp = File(exists=True, mandatory=True,
32+
desc=('dense deformation field to be applied'))
33+
interp = traits.Enum('cubic', 'nearest', 'linear', usedefault=True,
34+
mandatory=True, desc='interpolation')
35+
out_points = File(name_source='points', name_template='%s_warped',
36+
output_name='out_points', keep_extension=True,
37+
desc='the warped point set')
38+
39+
40+
class WarpPointsOutputSpec(TraitedSpec):
41+
out_points = File(desc='the warped point set')
42+
43+
44+
class WarpPoints(BaseInterface):
45+
46+
"""
47+
Applies a displacement field to a point set given in vtk format.
48+
Any discrete deformation field, given in physical coordinates and
49+
which volume covers the extent of the vtk point set, is a valid
50+
``warp`` file. FSL interfaces are compatible, for instance any
51+
field computed with :class:`nipype.interfaces.fsl.utils.ConvertWarp`.
52+
53+
Example
54+
-------
55+
56+
>>> from nipype.algorithms.mesh import WarpPoints
57+
>>> wp = WarpPoints()
58+
>>> wp.inputs.points = 'surf1.vtk'
59+
>>> wp.inputs.warp = 'warpfield.nii'
60+
>>> res = wp.run() # doctest: +SKIP
61+
"""
62+
input_spec = WarpPointsInputSpec
63+
output_spec = WarpPointsOutputSpec
64+
_redirect_x = True
65+
66+
def _gen_fname(self, in_file, suffix='generated', ext=None):
67+
import os.path as op
68+
69+
fname, fext = op.splitext(op.basename(in_file))
70+
71+
if fext == '.gz':
72+
fname, fext2 = op.splitext(fname)
73+
fext = fext2+fext
74+
75+
if ext is None:
76+
ext = fext
77+
78+
if ext[0] == '.':
79+
ext = ext[1:]
80+
return op.abspath('%s_%s.%s' % (fname, suffix, ext))
81+
82+
def _run_interface(self, runtime):
83+
vtk_major = 6
84+
try:
85+
import vtk
86+
vtk_major = vtk.VTK_MAJOR_VERSION
87+
except ImportError:
88+
iflogger.warn(('python-vtk could not be imported'))
89+
90+
try:
91+
from tvtk.api import tvtk
92+
except ImportError:
93+
raise ImportError('Interface requires tvtk')
94+
95+
try:
96+
from enthought.etsconfig.api import ETSConfig
97+
ETSConfig.toolkit = 'null'
98+
except ImportError:
99+
iflogger.warn(('ETS toolkit could not be imported'))
100+
except ValueError:
101+
iflogger.warn(('ETS toolkit could not be set to null'))
102+
103+
import nibabel as nb
104+
import numpy as np
105+
from scipy import ndimage
106+
107+
r = tvtk.PolyDataReader(file_name=self.inputs.points)
108+
r.update()
109+
mesh = r.output
110+
points = np.array(mesh.points)
111+
warp_dims = nb.funcs.four_to_three(nb.load(self.inputs.warp))
112+
113+
affine = warp_dims[0].get_affine()
114+
voxsize = warp_dims[0].get_header().get_zooms()
115+
vox2ras = affine[0:3, 0:3]
116+
ras2vox = np.linalg.inv(vox2ras)
117+
origin = affine[0:3, 3]
118+
voxpoints = np.array([np.dot(ras2vox,
119+
(p-origin)) for p in points])
120+
121+
warps = []
122+
for axis in warp_dims:
123+
wdata = axis.get_data()
124+
if np.any(wdata != 0):
125+
126+
warp = ndimage.map_coordinates(wdata,
127+
voxpoints.transpose())
128+
else:
129+
warp = np.zeros((points.shape[0],))
130+
131+
warps.append(warp)
132+
133+
disps = np.squeeze(np.dstack(warps))
134+
newpoints = [p+d for p, d in zip(points, disps)]
135+
mesh.points = newpoints
136+
w = tvtk.PolyDataWriter()
137+
if vtk_major <= 5:
138+
w.input = mesh
139+
else:
140+
w.set_input_data_object(mesh)
141+
142+
w.file_name = self._gen_fname(self.inputs.points,
143+
suffix='warped',
144+
ext='.vtk')
145+
w.write()
146+
return runtime
147+
148+
def _list_outputs(self):
149+
outputs = self._outputs().get()
150+
outputs['out_points'] = self._gen_fname(self.inputs.points,
151+
suffix='warped',
152+
ext='.vtk')
153+
return outputs
154+
155+
28156
class ComputeMeshWarpInputSpec(BaseInterfaceInputSpec):
29157
surface1 = File(exists=True, mandatory=True,
30158
desc=('Reference surface (vtk format) to which compute '
31159
'distance.'))
32160
surface2 = File(exists=True, mandatory=True,
161+
33162
desc=('Test surface (vtk format) from which compute '
34163
'distance.'))
35164
metric = traits.Enum('euclidean', 'sqeuclidean', usedefault=True,
@@ -101,10 +230,8 @@ def _run_interface(self, runtime):
101230
ETSConfig.toolkit = 'null'
102231
except ImportError:
103232
iflogger.warn(('ETS toolkit could not be imported'))
104-
pass
105233
except ValueError:
106234
iflogger.warn(('ETS toolkit is already set'))
107-
pass
108235

109236
r1 = tvtk.PolyDataReader(file_name=self.inputs.surface1)
110237
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
@@ -124,7 +251,6 @@ def _run_interface(self, runtime):
124251
errvector = nla.norm(diff, axis=1)
125252
except TypeError: # numpy < 1.9
126253
errvector = np.apply_along_axis(nla.norm, 1, diff)
127-
pass
128254

129255
if self.inputs.metric == 'sqeuclidean':
130256
errvector = errvector ** 2
@@ -235,10 +361,8 @@ def _run_interface(self, runtime):
235361
ETSConfig.toolkit = 'null'
236362
except ImportError:
237363
iflogger.warn(('ETS toolkit could not be imported'))
238-
pass
239364
except ValueError:
240365
iflogger.warn(('ETS toolkit is already set'))
241-
pass
242366

243367
r1 = tvtk.PolyDataReader(file_name=self.inputs.in_surf)
244368
vtk1 = r1.output

0 commit comments

Comments
 (0)