Skip to content

Commit 532611b

Browse files
committed
now vtk plots and saves figures with matplotlib
1 parent 7772091 commit 532611b

File tree

2 files changed

+47
-34
lines changed

2 files changed

+47
-34
lines changed

pygem/vtkhandler.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Utilities for reading and writing different CAD files.
33
"""
44
import numpy as np
5+
import matplotlib.pyplot as plt
6+
import mpl_toolkits.mplot3d as a3
57
import vtk
68
import pygem.filehandler as fh
79

@@ -47,7 +49,7 @@ def parse(self, filename):
4749
mesh_points = np.zeros([n_points, 3])
4850

4951
for i in range(n_points):
50-
mesh_points[i, 0], mesh_points[i, 1], mesh_points[i, 2] = data.GetPoint(i)
52+
mesh_points[i][0], mesh_points[i][1], mesh_points[i][2] = data.GetPoint(i)
5153

5254
return mesh_points
5355

@@ -95,11 +97,13 @@ def write(self, mesh_points, filename):
9597
writer.Write()
9698

9799

98-
def plot(self, plot_file=None):
100+
def plot(self, plot_file=None, save_fig=False):
99101
"""
100102
Method to plot an stl file. If `plot_file` is not given it plots `self.infile`.
101103
102104
:param string plot_file: the stl filename you want to plot.
105+
:param bool save_fig: a flag to save the figure in png or not. If True the
106+
plot is not shown.
103107
"""
104108
if plot_file is None:
105109
plot_file = self.infile
@@ -109,36 +113,37 @@ def plot(self, plot_file=None):
109113
# Read the source file.
110114
reader = vtk.vtkUnstructuredGridReader()
111115
reader.SetFileName(plot_file)
112-
reader.Update() # Needed because of GetScalarRange
113-
output = reader.GetOutput()
114-
scalar_range = output.GetScalarRange()
115-
116-
# Create the mapper that corresponds the objects of the vtk file
117-
# into graphics elements
118-
mapper = vtk.vtkDataSetMapper()
119-
if vtk.VTK_MAJOR_VERSION <= 5:
120-
mapper.SetInput(output)
116+
reader.Update()
117+
118+
data = reader.GetOutput()
119+
points = data.GetPoints()
120+
ncells = data.GetCells().GetNumberOfCells()
121+
122+
# for each cell it contains the indeces of the points that define the cell
123+
cells = np.zeros((ncells, 3))
124+
125+
for i in range(0, ncells):
126+
for j in range(0, 3):
127+
cells[i][j] = data.GetCell(i).GetPointId(j)
128+
129+
figure = plt.figure()
130+
axes = a3.Axes3D(figure)
131+
vtx = np.zeros((ncells, 3, 3))
132+
for i in range(0, ncells):
133+
for j in range(0, 3):
134+
vtx[i][j][0], vtx[i][j][1], vtx[i][j][2] = points.GetPoint(int(cells[i][j]))
135+
tri = a3.art3d.Poly3DCollection([vtx[i]])
136+
tri.set_color('b')
137+
tri.set_edgecolor('k')
138+
axes.add_collection3d(tri)
139+
140+
scale = vtx.flatten(-1)
141+
axes.auto_scale_xyz(scale, scale, scale)
142+
143+
# Show the plot to the screen
144+
if not save_fig:
145+
plt.show()
121146
else:
122-
mapper.SetInputData(output)
123-
mapper.SetScalarRange(scalar_range)
124-
125-
# Create the Actor
126-
actor = vtk.vtkActor()
127-
actor.SetMapper(mapper)
128-
129-
# Create the Renderer
130-
renderer = vtk.vtkRenderer()
131-
renderer.AddActor(actor)
132-
renderer.SetBackground(20, 20, 20) # Set background color (white is 1, 1, 1)
133-
134-
# Create the RendererWindow
135-
renderer_window = vtk.vtkRenderWindow()
136-
renderer_window.AddRenderer(renderer)
137-
138-
# Create the RendererWindowInteractor and display the vtk_file
139-
interactor = vtk.vtkRenderWindowInteractor()
140-
interactor.SetRenderWindow(renderer_window)
141-
interactor.Initialize()
142-
interactor.Start()
143-
144-
147+
figure.savefig(plot_file.split('.')[0] + '.png')
148+
149+

tests/test_vtkhandler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,11 @@ def test_vtk_plot_failing_outfile_type(self):
140140
with self.assertRaises(TypeError):
141141
vtk_handler.plot(plot_file=1.1)
142142

143+
144+
def test_vtk_plot_save_fig(self):
145+
vtk_handler = vh.VtkHandler()
146+
mesh_points = vtk_handler.parse('tests/test_datasets/test_red_blood_cell.vtk')
147+
vtk_handler.plot(save_fig=True)
148+
self.assertTrue(os.path.isfile('tests/test_datasets/test_red_blood_cell.png'))
149+
os.remove('tests/test_datasets/test_red_blood_cell.png')
150+

0 commit comments

Comments
 (0)