Skip to content

Commit 33ecc2f

Browse files
authored
Merge pull request #228 from ndem0/extra-param-file
rbf read extra_parameter from file
2 parents 701480a + e0de452 commit 33ecc2f

File tree

4 files changed

+120
-19
lines changed

4 files changed

+120
-19
lines changed

pygem/rbf.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ class RBF(Deformation):
108108
transformation.
109109
:cvar float radius: the scaling parameter that affects the shape of the
110110
basis functions.
111-
:cvar dict extra_parameter: the additional parameters that may be passed to
112-
the kernel function.
111+
:cvar dict extra: the additional parameters that may be passed to the
112+
kernel function.
113113
114114
:Example:
115115
116116
>>> from pygem import RBF
117117
>>> import numpy as np
118-
>>> rbf = RBF('gaussian_spline')
118+
>>> rbf = RBF(func='gaussian_spline')
119119
>>> xv = np.linspace(0, 1, 20)
120120
>>> yv = np.linspace(0, 1, 20)
121121
>>> zv = np.linspace(0, 1, 20)
@@ -208,7 +208,7 @@ def _get_weights(self, X, Y):
208208
"""
209209
npts, dim = X.shape
210210
H = np.zeros((npts + 3 + 1, npts + 3 + 1))
211-
H[:npts, :npts] = self.basis(cdist(X, X), self.radius)#, **self.extra)
211+
H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra)
212212
H[npts, :npts] = 1.0
213213
H[:npts, npts] = 1.0
214214
H[:npts, -3:] = X
@@ -239,8 +239,11 @@ def read_parameters(self, filename='parameters_rbf.prm'):
239239
config = configparser.RawConfigParser()
240240
config.read(filename)
241241

242-
self.basis = config.get('Radial Basis Functions', 'basis function')
243-
self.radius = config.getfloat('Radial Basis Functions', 'radius')
242+
rbf_settings = dict(config.items('Radial Basis Functions'))
243+
244+
self.basis = rbf_settings.pop('basis function')
245+
self.radius = float(rbf_settings.pop('radius'))
246+
self.extra = {k: eval(v) for k, v in rbf_settings.items()}
244247

245248
ctrl_points = config.get('Control points', 'original control points')
246249
lines = ctrl_points.split('\n')
@@ -331,6 +334,7 @@ def __str__(self):
331334
string = ''
332335
string += 'basis function = {}\n'.format(self.basis)
333336
string += 'radius = {}\n'.format(self.radius)
337+
string += 'extra_parameter = {}\n'.format(self.extra)
334338
string += '\noriginal control points =\n'
335339
string += '{}\n'.format(self.original_control_points)
336340
string += '\ndeformed control points =\n'
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
[Radial Basis Functions]
3+
# This section describes the radial basis functions shape.
4+
5+
# basis funtion is the name of the basis functions to use in the transformation. The functions
6+
# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline,
7+
# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline.
8+
# For a comprehensive list with details see the class RBF.
9+
basis function: polyharmonic_spline
10+
11+
# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation
12+
# of the class RBF for details.
13+
radius: 0.5
14+
15+
# Any additional parameter to pass to the basis function (eg the `k` power for poliharmonic_spline)
16+
k: 4
17+
18+
[Control points]
19+
# This section describes the RBF control points.
20+
21+
# original control points collects the coordinates of the interpolation control points before the deformation.
22+
original control points: 0.0 0.0 0.0
23+
0.0 0.0 1.0
24+
0.0 1.0 0.0
25+
1.0 0.0 0.0
26+
0.0 1.0 1.0
27+
1.0 0.0 1.0
28+
1.0 1.0 0.0
29+
1.0 1.0 1.0
30+
31+
# deformed control points collects the coordinates of the interpolation control points after the deformation.
32+
deformed control points: 0.0 0.0 0.0
33+
0.0 0.0 1.0
34+
0.0 1.0 0.0
35+
1.0 0.0 0.0
36+
0.0 1.0 1.0
37+
1.0 0.0 1.0
38+
1.0 1.0 0.0
39+
1.0 1.0 2.0
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
[Radial Basis Functions]
3+
# This section describes the radial basis functions shape.
4+
5+
# basis funtion is the name of the basis functions to use in the transformation. The functions
6+
# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline,
7+
# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline.
8+
# For a comprehensive list with details see the class RBF.
9+
basis function: gaussian_spline
10+
11+
# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation
12+
# of the class RBF for details.
13+
radius: 2.0
14+
15+
16+
17+
[Control points]
18+
# This section describes the RBF control points.
19+
20+
# original control points collects the coordinates of the interpolation control points before the deformation.
21+
original control points: 0.0 0.0 0.0
22+
0.0 0.0 1.0
23+
0.0 1.0 0.0
24+
1.0 0.0 0.0
25+
0.0 1.0 1.0
26+
1.0 0.0 1.0
27+
1.0 1.0 0.0
28+
1.0 1.0 1.0
29+
30+
# deformed control points collects the coordinates of the interpolation control points after the deformation.
31+
deformed control points: 0.0 0.0 0.0
32+
0.0 0.0 1.0
33+
0.0 1.0 0.0
34+
1.0 0.0 0.0
35+
0.0 1.0 1.0
36+
1.0 0.0 1.0
37+
1.0 1.0 0.0
38+
1.0 1.0 1.0

tests/test_rbf.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def test_class_members_default_radius(self):
5151
rbf = RBF()
5252
assert rbf.radius == 0.5
5353

54+
def test_class_members_default_extra(self):
55+
rbf = RBF()
56+
assert rbf.extra == {}
57+
5458
def test_class_members_default_n_control_points(self):
5559
rbf = RBF()
5660
assert rbf.n_control_points == 8
@@ -68,10 +72,20 @@ def test_read_parameters_basis(self):
6872
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
6973
assert rbf.basis == RBFFactory('gaussian_spline')
7074

75+
def test_read_parameters_basis2(self):
76+
rbf = RBF()
77+
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
78+
assert rbf.basis == RBFFactory('polyharmonic_spline')
79+
7180
def test_read_parameters_radius(self):
7281
rbf = RBF()
73-
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
74-
assert rbf.radius == 0.5
82+
rbf.read_parameters('tests/test_datasets/parameters_rbf_radius.prm')
83+
assert rbf.radius == 2.0
84+
85+
def test_read_extra_parameters(self):
86+
rbf = RBF()
87+
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
88+
assert rbf.extra == {'k': 4}
7589

7690
def test_read_parameters_n_control_points(self):
7791
rbf = RBF()
@@ -145,17 +159,23 @@ def test_write_parameters(self):
145159
146160
self.assertTrue(filecmp.cmp(outfilename, outfilename_expected))
147161
#os.remove(outfilename)
148-
149-
def test_read_parameters_filename_default(self):
150-
params = RBF()
151-
params.read_parameters()
152-
outfilename = 'parameters_rbf.prm'
153-
outfilename_expected = 'tests/test_datasets/parameters_rbf_default.prm'
154-
155-
self.assertTrue(filecmp.cmp(outfilename, outfilename_expected))
156-
os.remove(outfilename)
157162
"""
158163

159164
def test_print_info(self):
160-
params = RBF()
161-
print(params)
165+
rbf = RBF()
166+
print(rbf)
167+
168+
def test_call_dummy_transformation(self):
169+
rbf = RBF()
170+
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
171+
mesh = self.get_cube_mesh_points()
172+
new = rbf(mesh)
173+
np.testing.assert_array_almost_equal(new[17], mesh[17])
174+
175+
def test_call(self):
176+
rbf = RBF()
177+
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
178+
mesh = self.get_cube_mesh_points()
179+
new = rbf(mesh)
180+
np.testing.assert_array_almost_equal(new[17], [8.947368e-01, 5.353524e-17, 8.845331e-03])
181+

0 commit comments

Comments
 (0)