Skip to content

Commit b7ba47f

Browse files
committed
use tvtk appropriate tools whenever possible
1 parent 9339e5f commit b7ba47f

File tree

3 files changed

+32
-51
lines changed

3 files changed

+32
-51
lines changed

nipype/algorithms/mesh.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,12 @@ def _run_interface(self, runtime):
144144
import nibabel as nb
145145
import numpy as np
146146
from scipy import ndimage
147+
from tvtk.common import configure_input_data
148+
from tvtk.common import is_old_pipeline as vtk_old
147149

148150
r = tvtk.PolyDataReader(file_name=self.inputs.points)
149151
r.update()
150-
mesh = r.output
152+
mesh = r.output if vtk_old() else r.get_output()
151153
points = np.array(mesh.points)
152154
warp_dims = nb.funcs.four_to_three(nb.load(self.inputs.warp))
153155

@@ -174,21 +176,14 @@ def _run_interface(self, runtime):
174176
newpoints = [p + d for p, d in zip(points, disps)]
175177
mesh.points = newpoints
176178
w = tvtk.PolyDataWriter()
177-
if self.vtk_version()[0] < 6:
178-
w.input = mesh
179-
else:
180-
w.set_input_data_object(mesh)
181-
182-
w.file_name = self._gen_fname(self.inputs.points,
183-
suffix='warped',
184-
ext='.vtk')
179+
configure_input_data(w, mesh)
180+
w.file_name = self._gen_fname(self.inputs.points, suffix='warped', ext='.vtk')
185181
w.write()
186182
return runtime
187183

188184
def _list_outputs(self):
189185
outputs = self._outputs().get()
190-
outputs['out_points'] = self._gen_fname(self.inputs.points,
191-
suffix='warped',
186+
outputs['out_points'] = self._gen_fname(self.inputs.points, suffix='warped',
192187
ext='.vtk')
193188
return outputs
194189

@@ -258,10 +253,13 @@ def _triangle_area(self, A, B, C):
258253
return area
259254

260255
def _run_interface(self, runtime):
256+
from tvtk.common import configure_input_data
257+
from tvtk.common import is_old_pipeline as vtk_old
258+
261259
r1 = tvtk.PolyDataReader(file_name=self.inputs.surface1)
262260
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
263-
vtk1 = r1.output
264-
vtk2 = r2.output
261+
vtk1 = r1.output if vtk_old() else r1.get_output()
262+
vtk2 = r2.output if vtk_old() else r2.get_output()
265263
r1.update()
266264
r2.update()
267265
assert(len(vtk1.points) == len(vtk2.points))
@@ -305,12 +303,7 @@ def _run_interface(self, runtime):
305303
out_mesh.point_data.vectors.name = 'warpings'
306304
writer = tvtk.PolyDataWriter(
307305
file_name=op.abspath(self.inputs.out_warp))
308-
309-
if self.vtk_version()[0] < 6:
310-
writer.input = out_mesh
311-
else:
312-
writer.set_input_data_object(out_mesh)
313-
306+
configure_input_data(writer, out_mesh)
314307
writer.write()
315308

316309
self._distance = np.average(errvector, weights=weights)
@@ -379,8 +372,11 @@ class MeshWarpMaths(TVTKBaseInterface):
379372
output_spec = MeshWarpMathsOutputSpec
380373

381374
def _run_interface(self, runtime):
375+
from tvtk.common import configure_input_data
376+
from tvtk.common import is_old_pipeline as vtk_old
377+
382378
r1 = tvtk.PolyDataReader(file_name=self.inputs.in_surf)
383-
vtk1 = r1.output
379+
vtk1 = r1.output if vtk_old() else r1.get_output()
384380
r1.update()
385381
points1 = np.array(vtk1.points)
386382

@@ -392,7 +388,7 @@ def _run_interface(self, runtime):
392388

393389
if isinstance(operator, string_types):
394390
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
395-
vtk2 = r2.output
391+
vtk2 = r2.output if vtk_old() else r2.get_output()
396392
r2.update()
397393
assert(len(points1) == len(vtk2.points))
398394

@@ -425,25 +421,15 @@ def _run_interface(self, runtime):
425421
warping /= opfield
426422

427423
vtk1.point_data.vectors = warping
428-
writer = tvtk.PolyDataWriter(
429-
file_name=op.abspath(self.inputs.out_warp))
430-
if self.vtk_version()[0] < 6:
431-
writer.input = vtk1
432-
else:
433-
writer.set_input_data_object(vtk1)
424+
writer = tvtk.PolyDataWriter(file_name=op.abspath(self.inputs.out_warp))
425+
configure_input_data(writer, vtk1)
434426
writer.write()
435427

436428
vtk1.point_data.vectors = None
437429
vtk1.points = points1 + warping
438-
writer = tvtk.PolyDataWriter(
439-
file_name=op.abspath(self.inputs.out_file))
440-
441-
if self.vtk_version()[0] < 6:
442-
writer.input = vtk1
443-
else:
444-
writer.set_input_data_object(vtk1)
430+
writer = tvtk.PolyDataWriter(file_name=op.abspath(self.inputs.out_file))
431+
configure_input_data(writer, vtk1)
445432
writer.write()
446-
447433
return runtime
448434

449435
def _list_outputs(self):

nipype/algorithms/tests/test_mesh_ops.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_ident_distances():
2121
curdir = os.getcwd()
2222
os.chdir(tempdir)
2323

24-
if m.no_tvtk():
24+
if m.Info.no_tvtk():
2525
yield assert_raises, ImportError, m.ComputeMeshWarp
2626
else:
2727
in_surf = example_data('surf01.vtk')
@@ -49,23 +49,20 @@ def test_trans_distances():
4949
yield assert_raises, ImportError, m.ComputeMeshWarp
5050
else:
5151
from nipype.algorithms.mesh import tvtk
52+
from tvtk.common import is_old_pipeline as vtk_old
53+
from tvtk.common import configure_input_data
5254
in_surf = example_data('surf01.vtk')
5355
warped_surf = os.path.join(tempdir, 'warped.vtk')
5456

5557
inc = np.array([0.7, 0.3, -0.2])
5658

5759
r1 = tvtk.PolyDataReader(file_name=in_surf)
58-
vtk1 = r1.output
60+
vtk1 = r1.output if vtk_old() else r1.get_output()
5961
r1.update()
6062
vtk1.points = np.array(vtk1.points) + inc
6163

6264
writer = tvtk.PolyDataWriter(file_name=warped_surf)
63-
64-
if m.Info.vtk_version() < 6:
65-
writer.set_input(vtk1)
66-
else:
67-
writer.set_input_data_object(vtk1)
68-
65+
configure_input_data(writer, vtk1)
6966
writer.write()
7067

7168
dist = m.ComputeMeshWarp()

nipype/interfaces/fsl/utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,12 +1900,12 @@ def _vtk_to_coords(self, in_file, out_file=None):
19001900
raise ImportError('TVTK is required and tvtk package was not found')
19011901

19021902
from ...algorithms.mesh import tvtk
1903+
from tvtk.common import is_old_pipeline as vtk_old
19031904

1904-
vtk_major = VTKInfo.vtk_version()[0]
19051905
reader = tvtk.PolyDataReader(file_name=in_file + '.vtk')
19061906
reader.update()
19071907

1908-
mesh = reader.output if vtk_major < 6 else reader.get_output()
1908+
mesh = reader.output if vtk_old() else reader.get_output()
19091909
points = mesh.points
19101910

19111911
if out_file is None:
@@ -1921,19 +1921,17 @@ def _coords_to_vtk(self, points, out_file):
19211921
raise ImportError('TVTK is required and tvtk package was not found')
19221922

19231923
from ...algorithms.mesh import tvtk
1924+
from tvtk.common import is_old_pipeline as vtk_old
1925+
from tvtk.common import configure_input_data
19241926

1925-
vtk_major = VTKInfo.vtk_version()[0]
19261927
reader = tvtk.PolyDataReader(file_name=self.inputs.in_file)
19271928
reader.update()
19281929

1929-
mesh = reader.output if vtk_major < 6 else reader.get_output()
1930+
mesh = reader.output if vtk_old() else reader.get_output()
19301931
mesh.points = points
19311932

19321933
writer = tvtk.PolyDataWriter(file_name=out_file)
1933-
if vtk_major < 6:
1934-
writer.input = mesh
1935-
else:
1936-
writer.set_input_data_object(mesh)
1934+
configure_input_data(writer, mesh)
19371935
writer.write()
19381936

19391937
def _trk_to_coords(self, in_file, out_file=None):

0 commit comments

Comments
 (0)