Skip to content

Commit 9cafde5

Browse files
committed
refactored vtk/tvtk use and ETSConfigToolkit
1 parent b7ba47f commit 9cafde5

File tree

4 files changed

+113
-97
lines changed

4 files changed

+113
-97
lines changed

nipype/algorithms/mesh.py

Lines changed: 14 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from __future__ import division
1515
from builtins import zip
1616

17-
import os
1817
import os.path as op
1918
from warnings import warn
2019

@@ -25,52 +24,10 @@
2524
from ..external.six import string_types
2625
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
2726
BaseInterfaceInputSpec)
28-
27+
from ..interfaces.vtkbase import tvtk
28+
from ..interfaces import vtkbase as VTKInfo
2929
iflogger = logging.getLogger('interface')
3030

31-
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
32-
old_ets = os.getenv('ETS_TOOLKIT')
33-
os.environ['ETS_TOOLKIT'] = 'null'
34-
have_tvtk = False
35-
try:
36-
from tvtk.api import tvtk
37-
have_tvtk = True
38-
except ImportError:
39-
iflogger.warning('tvtk wasn\'t found')
40-
finally:
41-
if old_ets is not None:
42-
os.environ['ETS_TOOLKIT'] = old_ets
43-
else:
44-
del os.environ['ETS_TOOLKIT']
45-
46-
47-
class Info(object):
48-
""" Handle VTK version information """
49-
_vtk_version = None
50-
51-
@staticmethod
52-
def vtk_version():
53-
""" Get VTK version """
54-
if not Info.no_tvtk():
55-
return None
56-
57-
if Info._vtk_version is None:
58-
try:
59-
from tvtk.tvtk_classes.vtk_version import vtk_build_version
60-
vsplits = vtk_build_version.split('.')
61-
Info._vtk_version = tuple([int(vsplits[0]), int(vsplits[1])] + vsplits[2:])
62-
except ImportError:
63-
iflogger.warning(
64-
'VTK version-major inspection using tvtk failed, assuming VTK == 4.0.')
65-
Info._vtk_version = (4, 0)
66-
67-
return Info._vtk_version
68-
69-
@staticmethod
70-
def no_tvtk():
71-
global have_tvtk
72-
return not have_tvtk
73-
7431

7532
class TVTKBaseInterface(BaseInterface):
7633

@@ -79,13 +36,10 @@ class TVTKBaseInterface(BaseInterface):
7936
_redirect_x = True
8037

8138
def __init__(self, **inputs):
82-
if Info.no_tvtk():
39+
if VTKInfo.no_tvtk():
8340
raise ImportError('This interface requires tvtk to run.')
8441
super(TVTKBaseInterface, self).__init__(**inputs)
8542

86-
def vtk_version(self):
87-
return Info.vtk_version()
88-
8943

9044
class WarpPointsInputSpec(BaseInterfaceInputSpec):
9145
points = File(exists=True, mandatory=True,
@@ -144,17 +98,15 @@ def _run_interface(self, runtime):
14498
import nibabel as nb
14599
import numpy as np
146100
from scipy import ndimage
147-
from tvtk.common import configure_input_data
148-
from tvtk.common import is_old_pipeline as vtk_old
149101

150102
r = tvtk.PolyDataReader(file_name=self.inputs.points)
151103
r.update()
152-
mesh = r.output if vtk_old() else r.get_output()
104+
mesh = VTKInfo.vtk_output(r)
153105
points = np.array(mesh.points)
154106
warp_dims = nb.funcs.four_to_three(nb.load(self.inputs.warp))
155107

156108
affine = warp_dims[0].affine
157-
voxsize = warp_dims[0].header.get_zooms()
109+
# voxsize = warp_dims[0].header.get_zooms()
158110
vox2ras = affine[0:3, 0:3]
159111
ras2vox = np.linalg.inv(vox2ras)
160112
origin = affine[0:3, 3]
@@ -176,7 +128,7 @@ def _run_interface(self, runtime):
176128
newpoints = [p + d for p, d in zip(points, disps)]
177129
mesh.points = newpoints
178130
w = tvtk.PolyDataWriter()
179-
configure_input_data(w, mesh)
131+
VTKInfo.configure_input_data(w, mesh)
180132
w.file_name = self._gen_fname(self.inputs.points, suffix='warped', ext='.vtk')
181133
w.write()
182134
return runtime
@@ -253,13 +205,10 @@ def _triangle_area(self, A, B, C):
253205
return area
254206

255207
def _run_interface(self, runtime):
256-
from tvtk.common import configure_input_data
257-
from tvtk.common import is_old_pipeline as vtk_old
258-
259208
r1 = tvtk.PolyDataReader(file_name=self.inputs.surface1)
260209
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
261-
vtk1 = r1.output if vtk_old() else r1.get_output()
262-
vtk2 = r2.output if vtk_old() else r2.get_output()
210+
vtk1 = VTKInfo.vtk_output(r1)
211+
vtk2 = VTKInfo.vtk_output(r2)
263212
r1.update()
264213
r2.update()
265214
assert(len(vtk1.points) == len(vtk2.points))
@@ -303,7 +252,7 @@ def _run_interface(self, runtime):
303252
out_mesh.point_data.vectors.name = 'warpings'
304253
writer = tvtk.PolyDataWriter(
305254
file_name=op.abspath(self.inputs.out_warp))
306-
configure_input_data(writer, out_mesh)
255+
VTKInfo.configure_input_data(writer, out_mesh)
307256
writer.write()
308257

309258
self._distance = np.average(errvector, weights=weights)
@@ -372,11 +321,8 @@ class MeshWarpMaths(TVTKBaseInterface):
372321
output_spec = MeshWarpMathsOutputSpec
373322

374323
def _run_interface(self, runtime):
375-
from tvtk.common import configure_input_data
376-
from tvtk.common import is_old_pipeline as vtk_old
377-
378324
r1 = tvtk.PolyDataReader(file_name=self.inputs.in_surf)
379-
vtk1 = r1.output if vtk_old() else r1.get_output()
325+
vtk1 = VTKInfo.vtk_output(r1)
380326
r1.update()
381327
points1 = np.array(vtk1.points)
382328

@@ -388,7 +334,7 @@ def _run_interface(self, runtime):
388334

389335
if isinstance(operator, string_types):
390336
r2 = tvtk.PolyDataReader(file_name=self.inputs.surface2)
391-
vtk2 = r2.output if vtk_old() else r2.get_output()
337+
vtk2 = VTKInfo.vtk_output(r2)
392338
r2.update()
393339
assert(len(points1) == len(vtk2.points))
394340

@@ -399,7 +345,7 @@ def _run_interface(self, runtime):
399345

400346
if opfield is None:
401347
raise RuntimeError(
402-
('No operator values found in operator file'))
348+
'No operator values found in operator file')
403349

404350
opfield = np.array(opfield)
405351

@@ -422,13 +368,13 @@ def _run_interface(self, runtime):
422368

423369
vtk1.point_data.vectors = warping
424370
writer = tvtk.PolyDataWriter(file_name=op.abspath(self.inputs.out_warp))
425-
configure_input_data(writer, vtk1)
371+
VTKInfo.configure_input_data(writer, vtk1)
426372
writer.write()
427373

428374
vtk1.point_data.vectors = None
429375
vtk1.points = points1 + warping
430376
writer = tvtk.PolyDataWriter(file_name=op.abspath(self.inputs.out_file))
431-
configure_input_data(writer, vtk1)
377+
VTKInfo.configure_input_data(writer, vtk1)
432378
writer.write()
433379
return runtime
434380

nipype/algorithms/tests/test_mesh_ops.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,17 @@
88

99
from nipype.testing import (assert_equal, assert_raises, skipif,
1010
assert_almost_equal, example_data)
11-
1211
import numpy as np
13-
1412
from nipype.algorithms import mesh as m
15-
16-
import platform
13+
from ...interfaces import vtkbase as VTKInfo
1714

1815

1916
def test_ident_distances():
2017
tempdir = mkdtemp()
2118
curdir = os.getcwd()
2219
os.chdir(tempdir)
2320

24-
if m.Info.no_tvtk():
21+
if VTKInfo.no_tvtk():
2522
yield assert_raises, ImportError, m.ComputeMeshWarp
2623
else:
2724
in_surf = example_data('surf01.vtk')
@@ -45,24 +42,23 @@ def test_trans_distances():
4542
curdir = os.getcwd()
4643
os.chdir(tempdir)
4744

48-
if m.Info.no_tvtk():
45+
if VTKInfo.no_tvtk():
4946
yield assert_raises, ImportError, m.ComputeMeshWarp
5047
else:
51-
from nipype.algorithms.mesh import tvtk
52-
from tvtk.common import is_old_pipeline as vtk_old
53-
from tvtk.common import configure_input_data
48+
from ...interfaces.vtkbase import tvtk
49+
5450
in_surf = example_data('surf01.vtk')
5551
warped_surf = os.path.join(tempdir, 'warped.vtk')
5652

5753
inc = np.array([0.7, 0.3, -0.2])
5854

5955
r1 = tvtk.PolyDataReader(file_name=in_surf)
60-
vtk1 = r1.output if vtk_old() else r1.get_output()
56+
vtk1 = VTKInfo.vtk_output(r1)
6157
r1.update()
6258
vtk1.points = np.array(vtk1.points) + inc
6359

6460
writer = tvtk.PolyDataWriter(file_name=warped_surf)
65-
configure_input_data(writer, vtk1)
61+
VTKInfo.configure_input_data(writer, vtk1)
6662
writer.write()
6763

6864
dist = m.ComputeMeshWarp()
@@ -84,7 +80,7 @@ def test_warppoints():
8480
curdir = os.getcwd()
8581
os.chdir(tempdir)
8682

87-
if m.Info.no_tvtk():
83+
if VTKInfo.no_tvtk():
8884
yield assert_raises, ImportError, m.WarpPoints
8985

9086
# TODO: include regression tests for when tvtk is installed
@@ -98,7 +94,7 @@ def test_meshwarpmaths():
9894
curdir = os.getcwd()
9995
os.chdir(tempdir)
10096

101-
if m.Info.no_tvtk():
97+
if VTKInfo.no_tvtk():
10298
yield assert_raises, ImportError, m.MeshWarpMaths
10399

104100
# TODO: include regression tests for when tvtk is installed

nipype/interfaces/fsl/utils.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
from ...utils.filemanip import (load_json, save_json, split_filename,
3535
fname_presuffix, copyfile)
3636

37-
from ...algorithms.mesh import Info as VTKInfo
38-
3937
warn = warnings.warn
4038

4139

@@ -1895,17 +1893,15 @@ def _parse_inputs(self, skip=None):
18951893
return first_args + [second_args]
18961894

18971895
def _vtk_to_coords(self, in_file, out_file=None):
1898-
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
1896+
from ..vtkbase import tvtk
1897+
from ...interfaces import vtkbase as VTKInfo
1898+
18991899
if VTKInfo.no_tvtk():
19001900
raise ImportError('TVTK is required and tvtk package was not found')
19011901

1902-
from ...algorithms.mesh import tvtk
1903-
from tvtk.common import is_old_pipeline as vtk_old
1904-
19051902
reader = tvtk.PolyDataReader(file_name=in_file + '.vtk')
19061903
reader.update()
1907-
1908-
mesh = reader.output if vtk_old() else reader.get_output()
1904+
mesh = VTKInfo.vtk_output(reader)
19091905
points = mesh.points
19101906

19111907
if out_file is None:
@@ -1915,23 +1911,20 @@ def _vtk_to_coords(self, in_file, out_file=None):
19151911
return out_file
19161912

19171913
def _coords_to_vtk(self, points, out_file):
1918-
import os
1919-
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
1914+
from ..vtkbase import tvtk
1915+
from ...interfaces import vtkbase as VTKInfo
1916+
19201917
if VTKInfo.no_tvtk():
19211918
raise ImportError('TVTK is required and tvtk package was not found')
19221919

1923-
from ...algorithms.mesh import tvtk
1924-
from tvtk.common import is_old_pipeline as vtk_old
1925-
from tvtk.common import configure_input_data
1926-
19271920
reader = tvtk.PolyDataReader(file_name=self.inputs.in_file)
19281921
reader.update()
19291922

1930-
mesh = reader.output if vtk_old() else reader.get_output()
1923+
mesh = VTKInfo.vtk_output(reader)
19311924
mesh.points = points
19321925

19331926
writer = tvtk.PolyDataWriter(file_name=out_file)
1934-
configure_input_data(writer, mesh)
1927+
VTKInfo.configure_input_data(writer, mesh)
19351928
writer.write()
19361929

19371930
def _trk_to_coords(self, in_file, out_file=None):

nipype/interfaces/vtkbase.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""
4+
vtkbase provides some helpers to use VTK through the tvtk package (mayavi)
5+
6+
Code using tvtk should import it through this module
7+
"""
8+
9+
import os
10+
from .. import logging
11+
12+
iflogger = logging.getLogger('interface')
13+
14+
# Check that VTK can be imported and get version
15+
_vtk_version = None
16+
try:
17+
import vtk
18+
_vtk_version = (vtk.vtkVersion.GetVTKMajorVersion(),
19+
vtk.vtkVersion.GetVTKMinorVersion())
20+
except ImportError:
21+
iflogger.warning('VTK was not found')
22+
23+
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
24+
old_ets = os.getenv('ETS_TOOLKIT')
25+
os.environ['ETS_TOOLKIT'] = 'null'
26+
_have_tvtk = False
27+
try:
28+
from tvtk.api import tvtk
29+
_have_tvtk = True
30+
except ImportError:
31+
iflogger.warning('tvtk wasn\'t found')
32+
tvtk = None
33+
finally:
34+
if old_ets is not None:
35+
os.environ['ETS_TOOLKIT'] = old_ets
36+
else:
37+
del os.environ['ETS_TOOLKIT']
38+
39+
40+
def vtk_version():
41+
""" Get VTK version """
42+
global _vtk_version
43+
return _vtk_version
44+
45+
46+
def no_vtk():
47+
""" Checks if VTK is installed and the python wrapper is functional """
48+
global _vtk_version
49+
return _vtk_version is None
50+
51+
52+
def no_tvtk():
53+
""" Checks if tvtk was found """
54+
global _have_tvtk
55+
return not _have_tvtk
56+
57+
58+
def vtk_old():
59+
""" Checks if VTK uses the old-style pipeline (VTK<6.0) """
60+
global _vtk_version
61+
if _vtk_version is None:
62+
raise RuntimeException('VTK is not correctly installed.')
63+
return _vtk_version[0] < 6
64+
65+
66+
def configure_input_data(obj, data):
67+
"""
68+
Configure the input data for vtk pipeline object obj.
69+
Copied from latest version of mayavi
70+
"""
71+
if vtk_old():
72+
obj.input = data
73+
else:
74+
obj.set_input_data(data)
75+
76+
77+
def vtk_output(obj):
78+
""" Configure the input data for vtk pipeline object obj."""
79+
if vtk_old():
80+
return obj.output
81+
return obj.get_output()

0 commit comments

Comments
 (0)