Skip to content

Commit 3bab6de

Browse files
committed
IDW refactoring
1 parent 1daf6f0 commit 3bab6de

File tree

5 files changed

+203
-123
lines changed

5 files changed

+203
-123
lines changed

pygem/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get_current_year():
2424
from .deformation import Deformation
2525
from .ffd import FFD
2626
from .rbf import RBF
27+
from .idw import IDW
2728
from .rbf_factory import RBFFactory
2829
#from .radial import RBF
2930
#from .idw import IDW

pygem/idw.py

Lines changed: 128 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,70 +36,83 @@
3636
:math:`\\mathrm{x}_i` and :math:`p` is a power parameter, typically equal to
3737
2.
3838
"""
39+
import os
3940
import numpy as np
41+
try:
42+
import configparser as configparser
43+
except ImportError:
44+
import ConfigParser as configparser
45+
4046
from scipy.spatial.distance import cdist
47+
from .deformation import Deformation
4148

4249

43-
class IDW(object):
50+
class IDW(Deformation):
4451
"""
45-
Class that handles the IDW technique.
46-
47-
:param idw_parameters: the parameters of the IDW
48-
:type idw_parameters: :class:`IDWParameters`
49-
:param numpy.ndarray original_mesh_points: coordinates of the original
50-
points of the mesh.
51-
52-
:cvar parameters: the parameters of the IDW.
53-
:vartype parameters: :class:`~pygem.params_idw.IDWParameters`
54-
:cvar numpy.ndarray original_mesh_points: coordinates of the original
55-
points of the mesh.
56-
:cvar numpy.ndarray modified_mesh_points: coordinates of the deformed
57-
points of the mesh.
52+
Class that perform the Inverse Distance Weighting (IDW).
53+
54+
:cvar int power: the power parameter. The default value is 2.
55+
:cvar numpy.ndarray original_control_points: it is an
56+
`n_control_points`-by-3 array with the coordinates of the original
57+
interpolation control points before the deformation. The default is the
58+
vertices of the unit cube.
59+
:cvar numpy.ndarray deformed_control_points: it is an
60+
`n_control_points`-by-3 array with the coordinates of the interpolation
61+
control points after the deformation. The default is the vertices of
62+
the unit cube.
5863
5964
:Example:
6065
61-
>>> from pygem.idw import IDW
62-
>>> from pygem.params_idw import IDWParameters
66+
>>> from pygem import IDW
6367
>>> import numpy as np
64-
>>> params = IDWParameters()
65-
>>> params.read_parameters('tests/test_datasets/parameters_idw_cube.prm')
6668
>>> nx, ny, nz = (20, 20, 20)
6769
>>> mesh = np.zeros((nx * ny * nz, 3))
6870
>>> xv = np.linspace(0, 1, nx)
6971
>>> yv = np.linspace(0, 1, ny)
7072
>>> zv = np.linspace(0, 1, nz)
7173
>>> z, y, x = np.meshgrid(zv, yv, xv)
72-
>>> mesh = np.array([x.ravel(), y.ravel(), z.ravel()])
73-
>>> original_mesh_points = mesh.T
74-
>>> idw = IDW(rbf_parameters, original_mesh_points)
75-
>>> idw.perform()
76-
>>> new_mesh_points = idw.modified_mesh_points
74+
>>> mesh_points = np.array([x.ravel(), y.ravel(), z.ravel()])
75+
>>> idw = IDW()
76+
>>> idw.read_parameters('tests/test_datasets/parameters_idw_cube.prm')
77+
>>> new_mesh_points = idw(mesh_points.T)
7778
"""
7879

79-
def __init__(self, idw_parameters, original_mesh_points):
80-
self.parameters = idw_parameters
81-
self.original_mesh_points = original_mesh_points
82-
self.modified_mesh_points = None
80+
def __init__(self,
81+
original_control_points=None,
82+
deformed_control_points=None,
83+
power=2):
84+
85+
if original_control_points is None:
86+
self.original_control_points = np.array(
87+
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
88+
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]])
89+
else:
90+
self.original_control_points = original_control_points
91+
92+
if deformed_control_points is None:
93+
self.deformed_control_points = np.array(
94+
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
95+
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]])
96+
else:
97+
self.deformed_control_points = deformed_control_points
98+
99+
self.power = power
83100

84-
def perform(self):
101+
def __call__(self, src_pts):
85102
"""
86103
This method performs the deformation of the mesh points. After the
87104
execution it sets `self.modified_mesh_points`.
88105
"""
89106

90107
def distance(u, v):
91-
"""
92-
Norm of u - v
93-
"""
94-
return np.linalg.norm(u - v, ord=self.parameters.power)
108+
""" Norm of u - v """
109+
return np.linalg.norm(u - v, ord=self.power)
95110

96111
# Compute displacement of the control points
97-
displ = (self.parameters.deformed_control_points -
98-
self.parameters.original_control_points)
112+
displ = self.deformed_control_points - self.original_control_points
99113

100114
# Compute the distance between the mesh points and the control points
101-
dist = cdist(self.original_mesh_points,
102-
self.parameters.original_control_points, distance)
115+
dist = cdist(src_pts, self.original_control_points, distance)
103116

104117
# Weights are set as the reciprocal of the distance if the distance is
105118
# not zero, otherwise 1.0 where distance is zero.
@@ -112,4 +125,82 @@ def distance(u, v):
112125
for wi in weights
113126
])
114127

115-
self.modified_mesh_points = self.original_mesh_points + offset
128+
return src_pts + offset
129+
130+
def read_parameters(self, filename):
131+
"""
132+
Reads in the parameters file and fill the self structure.
133+
134+
:param string filename: parameters file to be read in.
135+
"""
136+
if not isinstance(filename, str):
137+
raise TypeError('filename must be a string')
138+
139+
if not os.path.isfile(filename):
140+
raise IOError('filename does not exist')
141+
142+
config = configparser.RawConfigParser()
143+
config.read(filename)
144+
145+
self.power = config.getint('Inverse Distance Weighting', 'power')
146+
147+
ctrl_points = config.get('Control points', 'original control points')
148+
self.original_control_points = np.array(
149+
[line.split() for line in ctrl_points.split('\n')], dtype=float)
150+
151+
defo_points = config.get('Control points', 'deformed control points')
152+
self.deformed_control_points = np.array(
153+
[line.split() for line in defo_points.split('\n')], dtype=float)
154+
155+
def write_parameters(self, filename):
156+
"""
157+
This method writes a parameters file (.prm) called `filename` and fills
158+
it with all the parameters class members.
159+
160+
:param string filename: parameters file to be written out.
161+
"""
162+
if not isinstance(filename, str):
163+
raise TypeError("filename must be a string")
164+
165+
output_string = ""
166+
output_string += "\n[Inverse Distance Weighting]\n"
167+
output_string += "# This section describes the settings of idw.\n\n"
168+
output_string += "# the power parameter\n"
169+
output_string += "power = {}\n".format(self.power)
170+
171+
output_string += "\n\n[Control points]\n"
172+
output_string += "# This section describes the IDW control points.\n\n"
173+
output_string += "# original control points collects the coordinates\n"
174+
output_string += "# of the interpolation control points before the\n"
175+
output_string += "# deformation.\n"
176+
177+
output_string += "original control points: "
178+
output_string += (
179+
' '.join(map(str, self.original_control_points[0])) + "\n")
180+
for points in self.original_control_points[1:]:
181+
output_string += 25 * ' ' + ' '.join(map(str, points)) + "\n"
182+
output_string += "\n"
183+
output_string += "# deformed control points collects the coordinates\n"
184+
output_string += "# of the interpolation control points after the\n"
185+
output_string += "# deformation.\n"
186+
output_string += "deformed control points: "
187+
output_string += (
188+
' '.join(map(str, self.original_control_points[0])) + "\n")
189+
for points in self.deformed_control_points[1:]:
190+
output_string += 25 * ' ' + ' '.join(map(str, points)) + "\n"
191+
192+
with open(filename, 'w') as f:
193+
f.write(output_string)
194+
195+
def __str__(self):
196+
"""
197+
This method prints all the IDW parameters on the screen. Its purpose is
198+
for debugging.
199+
"""
200+
string = ''
201+
string += 'p = {}\n'.format(self.power)
202+
string += '\noriginal_control_points =\n'
203+
string += '{}\n'.format(self.original_control_points)
204+
string += '\ndeformed_control_points =\n'
205+
string += '{}\n'.format(self.deformed_control_points)
206+
return string

pygem/rbf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@
6060

6161
from scipy.spatial.distance import cdist
6262

63+
from .deformation import Deformation
6364
from .rbf_factory import RBFFactory
6465

6566

66-
class RBF(object):
67+
class RBF(Deformation):
6768
"""
6869
Class that handles the Radial Basis Functions interpolation on the mesh
6970
points.

tests/test_idw.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1+
import os
2+
import filecmp
3+
import numpy as np
14
from unittest import TestCase
2-
import unittest
35
from pygem import IDW
4-
from pygem import IDWParameters
5-
import numpy as np
6-
76

87
class TestIDW(TestCase):
98
def get_cube_mesh_points(self):
@@ -19,20 +18,76 @@ def get_cube_mesh_points(self):
1918
return original_mesh_points
2019

2120
def test_idw(self):
22-
params = IDWParameters()
23-
params.read_parameters('tests/test_datasets/parameters_idw_default.prm')
24-
idw = IDW(params, self.get_cube_mesh_points())
21+
idw = IDW()
2522

26-
def test_idw_perform(self):
27-
params = IDWParameters()
28-
params.read_parameters('tests/test_datasets/parameters_idw_default.prm')
29-
IDW(params, self.get_cube_mesh_points()).perform()
23+
def test_idw_call(self):
24+
idw = IDW()
25+
idw.read_parameters('tests/test_datasets/parameters_idw_default.prm')
26+
idw(self.get_cube_mesh_points())
3027

3128
def test_idw_perform_deform(self):
32-
params = IDWParameters()
29+
idw = IDW()
3330
expected_stretch = [1.19541593, 1.36081491, 1.42095073]
34-
params.read_parameters('tests/test_datasets/parameters_idw_deform.prm')
35-
idw = IDW(params, self.get_cube_mesh_points())
36-
idw.perform()
37-
np.testing.assert_array_almost_equal(idw.modified_mesh_points[-3],
38-
expected_stretch)
31+
idw.read_parameters('tests/test_datasets/parameters_idw_deform.prm')
32+
new_pts = idw(self.get_cube_mesh_points())
33+
np.testing.assert_array_almost_equal(new_pts[-3], expected_stretch)
34+
35+
def test_class_members_default_p(self):
36+
idw = IDW()
37+
assert idw.power == 2
38+
39+
def test_class_members_default_original_points(self):
40+
idw = IDW()
41+
cube_vertices = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.],
42+
[1., 0., 0.], [0., 1., 1.], [1., 0., 1.],
43+
[1., 1., 0.], [1., 1., 1.]])
44+
np.testing.assert_equal(idw.original_control_points, cube_vertices)
45+
46+
def test_class_members_default_deformed_points(self):
47+
idw = IDW()
48+
cube_vertices = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.],
49+
[1., 0., 0.], [0., 1., 1.], [1., 0., 1.],
50+
[1., 1., 0.], [1., 1., 1.]])
51+
np.testing.assert_equal(idw.deformed_control_points, cube_vertices)
52+
53+
def test_write_parameters_filename_default(self):
54+
params = IDW()
55+
outfilename = 'parameters_rbf.prm'
56+
outfilename_expected = 'tests/test_datasets/parameters_idw_default.prm'
57+
params.write_parameters(outfilename)
58+
self.assertTrue(filecmp.cmp(outfilename, outfilename_expected))
59+
os.remove(outfilename)
60+
61+
def test_write_not_string(self):
62+
params = IDW()
63+
with self.assertRaises(TypeError):
64+
params.write_parameters(5)
65+
66+
def test_read_deformed(self):
67+
params = IDW()
68+
filename = 'tests/test_datasets/parameters_idw_deform.prm'
69+
def_vertices = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.],
70+
[1., 0., 0.], [0., 1., 1.], [1., 0., 1.],
71+
[1., 1., 0.], [1.5, 1.6, 1.7]])
72+
params.read_parameters(filename)
73+
np.testing.assert_equal(params.deformed_control_points, def_vertices)
74+
75+
def test_read_p(self):
76+
idw = IDW()
77+
filename = 'tests/test_datasets/parameters_idw_deform.prm'
78+
idw.read_parameters(filename)
79+
assert idw.power == 3
80+
81+
def test_read_not_string(self):
82+
idw = IDW()
83+
with self.assertRaises(TypeError):
84+
idw.read_parameters(5)
85+
86+
def test_read_not_real_file(self):
87+
idw = IDW()
88+
with self.assertRaises(IOError):
89+
idw.read_parameters('not_real_file')
90+
91+
def test_print(self):
92+
idw = IDW()
93+
print(idw)

tests/test_idwparams.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments
 (0)