Skip to content

Commit 9339e5f

Browse files
committed
encapsulated VTK Info in class
1 parent 5396852 commit 9339e5f

File tree

3 files changed

+56
-54
lines changed

3 files changed

+56
-54
lines changed

nipype/algorithms/mesh.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,33 +44,47 @@
4444
del os.environ['ETS_TOOLKIT']
4545

4646

47-
def no_tvtk():
48-
global have_tvtk
49-
return not have_tvtk
47+
class Info(object):
48+
""" Handle VTK version information """
49+
_vtk_version = None
50+
51+
@staticmethod
52+
def vtk_version():
53+
""" Get VTK version """
54+
if not Info.no_tvtk():
55+
return None
56+
57+
if Info._vtk_version is None:
58+
try:
59+
from tvtk.tvtk_classes.vtk_version import vtk_build_version
60+
vsplits = vtk_build_version.split('.')
61+
Info._vtk_version = tuple([int(vsplits[0]), int(vsplits[1])] + vsplits[2:])
62+
except ImportError:
63+
iflogger.warning(
64+
'VTK version-major inspection using tvtk failed, assuming VTK == 4.0.')
65+
Info._vtk_version = (4, 0)
66+
67+
return Info._vtk_version
68+
69+
@staticmethod
70+
def no_tvtk():
71+
global have_tvtk
72+
return not have_tvtk
5073

5174

5275
class TVTKBaseInterface(BaseInterface):
5376

5477
""" A base class for interfaces using VTK """
5578

5679
_redirect_x = True
57-
_vtk_version = (4, 0, 0)
5880

5981
def __init__(self, **inputs):
60-
if no_tvtk():
82+
if Info.no_tvtk():
6183
raise ImportError('This interface requires tvtk to run.')
62-
63-
try:
64-
from tvtk.tvtk_classes.vtk_version import vtk_build_version
65-
vsplits = vtk_build_version.split('.')
66-
self._vtk_version = tuple([int(vsplits[0]), int(vsplits[1])] + vsplits[2:])
67-
except ImportError:
68-
iflogger.warning(
69-
'VTK version-major inspection using tvtk failed, assuming VTK == 4.0.')
7084
super(TVTKBaseInterface, self).__init__(**inputs)
7185

72-
def version(self):
73-
return self._vtk_version
86+
def vtk_version(self):
87+
return Info.vtk_version()
7488

7589

7690
class WarpPointsInputSpec(BaseInterfaceInputSpec):
@@ -160,7 +174,7 @@ def _run_interface(self, runtime):
160174
newpoints = [p + d for p, d in zip(points, disps)]
161175
mesh.points = newpoints
162176
w = tvtk.PolyDataWriter()
163-
if self.version()[0] < 6:
177+
if self.vtk_version()[0] < 6:
164178
w.input = mesh
165179
else:
166180
w.set_input_data_object(mesh)
@@ -292,7 +306,7 @@ def _run_interface(self, runtime):
292306
writer = tvtk.PolyDataWriter(
293307
file_name=op.abspath(self.inputs.out_warp))
294308

295-
if self.version()[0] < 6:
309+
if self.vtk_version()[0] < 6:
296310
writer.input = out_mesh
297311
else:
298312
writer.set_input_data_object(out_mesh)
@@ -413,7 +427,7 @@ def _run_interface(self, runtime):
413427
vtk1.point_data.vectors = warping
414428
writer = tvtk.PolyDataWriter(
415429
file_name=op.abspath(self.inputs.out_warp))
416-
if self.version()[0] < 6:
430+
if self.vtk_version()[0] < 6:
417431
writer.input = vtk1
418432
else:
419433
writer.set_input_data_object(vtk1)
@@ -424,7 +438,7 @@ def _run_interface(self, runtime):
424438
writer = tvtk.PolyDataWriter(
425439
file_name=op.abspath(self.inputs.out_file))
426440

427-
if self.version()[0] < 6:
441+
if self.vtk_version()[0] < 6:
428442
writer.input = vtk1
429443
else:
430444
writer.set_input_data_object(vtk1)

nipype/algorithms/tests/test_mesh_ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_trans_distances():
4545
curdir = os.getcwd()
4646
os.chdir(tempdir)
4747

48-
if m.no_tvtk():
48+
if m.Info.no_tvtk():
4949
yield assert_raises, ImportError, m.ComputeMeshWarp
5050
else:
5151
from nipype.algorithms.mesh import tvtk
@@ -60,7 +60,12 @@ def test_trans_distances():
6060
vtk1.points = np.array(vtk1.points) + inc
6161

6262
writer = tvtk.PolyDataWriter(file_name=warped_surf)
63-
writer.set_input_data(vtk1)
63+
64+
if m.Info.vtk_version() < 6:
65+
writer.set_input(vtk1)
66+
else:
67+
writer.set_input_data_object(vtk1)
68+
6469
writer.write()
6570

6671
dist = m.ComputeMeshWarp()
@@ -82,7 +87,7 @@ def test_warppoints():
8287
curdir = os.getcwd()
8388
os.chdir(tempdir)
8489

85-
if m.no_tvtk():
90+
if m.Info.no_tvtk():
8691
yield assert_raises, ImportError, m.WarpPoints
8792

8893
# TODO: include regression tests for when tvtk is installed
@@ -96,7 +101,7 @@ def test_meshwarpmaths():
96101
curdir = os.getcwd()
97102
os.chdir(tempdir)
98103

99-
if m.no_tvtk():
104+
if m.Info.no_tvtk():
100105
yield assert_raises, ImportError, m.MeshWarpMaths
101106

102107
# TODO: include regression tests for when tvtk is installed

nipype/interfaces/fsl/utils.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from ...utils.filemanip import (load_json, save_json, split_filename,
3535
fname_presuffix, copyfile)
3636

37+
from ...algorithms.mesh import Info as VTKInfo
38+
3739
warn = warnings.warn
3840

3941

@@ -1872,8 +1874,8 @@ def __init__(self, command=None, **inputs):
18721874
def _format_arg(self, name, trait_spec, value):
18731875
if name == 'out_file':
18741876
return ''
1875-
else:
1876-
return super(WarpPoints, self)._format_arg(name, trait_spec, value)
1877+
1878+
return super(WarpPoints, self)._format_arg(name, trait_spec, value)
18771879

18781880
def _parse_inputs(self, skip=None):
18791881
fname, ext = op.splitext(self.inputs.in_coords)
@@ -1894,21 +1896,12 @@ def _parse_inputs(self, skip=None):
18941896

18951897
def _vtk_to_coords(self, in_file, out_file=None):
18961898
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
1897-
old_ets = os.getenv('ETS_TOOLKIT')
1898-
os.environ['ETS_TOOLKIT'] = 'null'
1899-
try:
1900-
from tvtk.api import tvtk
1901-
from tvtk.tvtk_classes.vtk_version import vtk_build_version
1902-
except ImportError:
1903-
vtk_build_version = None
1904-
raise ImportError('This interface requires tvtk to run.')
1905-
finally:
1906-
if old_ets is not None:
1907-
os.environ['ETS_TOOLKIT'] = old_ets
1908-
else:
1909-
del os.environ['ETS_TOOLKIT']
1899+
if VTKInfo.no_tvtk():
1900+
raise ImportError('TVTK is required and tvtk package was not found')
19101901

1911-
vtk_major = int(vtk_build_version[0])
1902+
from ...algorithms.mesh import tvtk
1903+
1904+
vtk_major = VTKInfo.vtk_version()[0]
19121905
reader = tvtk.PolyDataReader(file_name=in_file + '.vtk')
19131906
reader.update()
19141907

@@ -1923,23 +1916,13 @@ def _vtk_to_coords(self, in_file, out_file=None):
19231916

19241917
def _coords_to_vtk(self, points, out_file):
19251918
import os
1926-
19271919
# Ensure that tvtk is loaded with the appropriate ETS_TOOLKIT env var
1928-
old_ets = os.getenv('ETS_TOOLKIT')
1929-
os.environ['ETS_TOOLKIT'] = 'null'
1930-
try:
1931-
from tvtk.api import tvtk
1932-
from tvtk.tvtk_classes.vtk_version import vtk_build_version
1933-
except ImportError:
1934-
vtk_build_version = None
1935-
raise ImportError('This interface requires tvtk to run.')
1936-
finally:
1937-
if old_ets is not None:
1938-
os.environ['ETS_TOOLKIT'] = old_ets
1939-
else:
1940-
del os.environ['ETS_TOOLKIT']
1920+
if VTKInfo.no_tvtk():
1921+
raise ImportError('TVTK is required and tvtk package was not found')
1922+
1923+
from ...algorithms.mesh import tvtk
19411924

1942-
vtk_major = int(vtk_build_version[0])
1925+
vtk_major = VTKInfo.vtk_version()[0]
19431926
reader = tvtk.PolyDataReader(file_name=self.inputs.in_file)
19441927
reader.update()
19451928

0 commit comments

Comments
 (0)