5757"""
5858import os
5959import numpy as np
60+ try :
61+ import configparser as configparser
62+ except ImportError :
63+ import ConfigParser as configparser
64+
6065
6166from scipy .spatial .distance import cdist
6267
@@ -125,13 +130,7 @@ def __init__(self,
125130 radius = 0.5 ,
126131 extra_parameter = None ):
127132
128- if callable (func ):
129- self .basis = func
130- elif isinstance (func , str ):
131- self .basis = RBFFactory (func )
132- else :
133- raise TypeError ('`func` is not valid.' )
134-
133+ self .basis = func
135134 self .radius = radius
136135
137136 if original_control_points is None :
@@ -157,6 +156,7 @@ def __init__(self,
157156 self .weights = self ._get_weights (self .original_control_points ,
158157 self .deformed_control_points )
159158
159+
160160 @property
161161 def n_control_points (self ):
162162 """
@@ -166,6 +166,29 @@ def n_control_points(self):
166166 """
167167 return self .original_control_points .shape [0 ]
168168
169+ @property
170+ def basis (self ):
171+ """
172+ The kernel to use in the deformation.
173+
174+ :getter: Returns the callable kernel
175+ :setter: Sets the kernel. It is possible to pass the name of the
176+ function (check the list of all implemented functions in the
177+ `pygem.rbf_factory.RBFFactory` class) or directly the callable
178+ function.
179+ :type: callable
180+ """
181+ return self .__basis
182+
183+ @basis .setter
184+ def basis (self , func ):
185+ if callable (func ):
186+ self .__basis = func
187+ elif isinstance (func , str ):
188+ self .__basis = RBFFactory (func )
189+ else :
190+ raise TypeError ('`func` is not valid.' )
191+
169192 def _get_weights (self , X , Y ):
170193 """
171194 This private method, given the original control points and the deformed
@@ -185,7 +208,7 @@ def _get_weights(self, X, Y):
185208 """
186209 npts , dim = X .shape
187210 H = np .zeros ((npts + 3 + 1 , npts + 3 + 1 ))
188- H [:npts , :npts ] = self .basis (cdist (X , X ), self .radius , ** self .extra )
211+ H [:npts , :npts ] = self .basis (cdist (X , X ), self .radius ) # , **self.extra)
189212 H [npts , :npts ] = 1.0
190213 H [:npts , npts ] = 1.0
191214 H [:npts , - 3 :] = X
@@ -221,13 +244,14 @@ def read_parameters(self, filename='parameters_rbf.prm'):
221244
222245 ctrl_points = config .get ('Control points' , 'original control points' )
223246 lines = ctrl_points .split ('\n ' )
224- self . original_control_points = np .zeros ((len (lines ), 3 ))
247+ original_control_points = np .zeros ((len (lines ), 3 ))
225248 for line , i in zip (lines , list (range (0 , self .n_control_points ))):
226249 values = line .split ()
227- self . original_control_points [i ] = np .array (
250+ original_control_points [i ] = np .array (
228251 [float (values [0 ]),
229252 float (values [1 ]),
230253 float (values [2 ])])
254+ self .original_control_points = original_control_points
231255
232256 mod_points = config .get ('Control points' , 'deformed control points' )
233257 lines = mod_points .split ('\n ' )
@@ -238,13 +262,15 @@ def read_parameters(self, filename='parameters_rbf.prm'):
238262 "control points' section of the parameters file"
239263 "({0!s})" .format (filename ))
240264
241- self . deformed_control_points = np .zeros ((self .n_control_points , 3 ))
265+ deformed_control_points = np .zeros ((self .n_control_points , 3 ))
242266 for line , i in zip (lines , list (range (0 , self .n_control_points ))):
243267 values = line .split ()
244- self . deformed_control_points [i ] = np .array (
268+ deformed_control_points [i ] = np .array (
245269 [float (values [0 ]),
246270 float (values [1 ]),
247271 float (values [2 ])])
272+ self .deformed_control_points = deformed_control_points
273+
248274
249275 def write_parameters (self , filename = 'parameters_rbf.prm' ):
250276 """
@@ -271,7 +297,7 @@ def write_parameters(self, filename='parameters_rbf.prm'):
271297 output_string += ' polyharmonic_spline.\n '
272298 output_string += '# For a comprehensive list with details see the'
273299 output_string += ' class RBF.\n '
274- output_string += 'basis function: {}\n ' .format (str ( self . basis ) )
300+ output_string += 'basis function: {}\n ' .format ('gaussian_spline' )
275301
276302 output_string += '\n # radius is the scaling parameter r that affects'
277303 output_string += ' the shape of the basis functions. See the'
@@ -362,15 +388,26 @@ def plot_points(self, filename=None):
362388 else :
363389 fig .savefig (filename )
364390
391+ def compute_weights (self ):
392+ """
393+ This method compute the weights according to the
394+ `original_control_points` and `deformed_control_points` arrays.
395+ """
396+ self .weights = self ._get_weights (self .original_control_points ,
397+ self .deformed_control_points )
398+
365399 def __call__ (self , src_pts ):
366400 """
367401 This method performs the deformation of the mesh points. After the
368402 execution it sets `self.modified_mesh_points`.
369403 """
370- H = np .zeros ((n_mesh_points , self .n_control_points + 3 + 1 ))
404+ self .compute_weights ()
405+
406+ H = np .zeros ((src_pts .shape [0 ], self .n_control_points + 3 + 1 ))
371407 H [:, :self .n_control_points ] = self .basis (
372- cdist (src_pts , self .original_control_points ), self .radius ,
373- ** self .extra )
374- H [:, n_control_points ] = 1.0
375- H [:, - 3 :] = self .original_mesh_points
376- self .modified_mesh_points = np .asarray (np .dot (H , self .weights ))
408+ cdist (src_pts , self .original_control_points ),
409+ self .radius )
410+ #**self.extra)
411+ H [:, self .n_control_points ] = 1.0
412+ H [:, - 3 :] = src_pts
413+ return np .asarray (np .dot (H , self .weights ))
0 commit comments