7070
7171import matplotlib .pyplot as plt
7272
73+ import warnings
74+
7375
7476class RBF (Deformation ):
7577 """
@@ -93,8 +95,12 @@ class RBF(Deformation):
9395 basis functions. For details see the class
9496 :class:`RBF`. The default value is 0.5.
9597 :param dict extra_parameter: the additional parameters that may be passed to
96- the kernel function. Default is None.
97-
98+ the kernel function. Default is None.
99+ :param str dtype: Precision specification. Supported values:
100+ 'fp16'/'float16', 'fp32'/'float32', 'fp64'/'float64' (default),
101+ 'fp96'/'float96','fp128'/'float128' (if available on platform).
102+ Default is 'fp64'.
103+
98104 :cvar numpy.ndarray weights: the matrix formed by the weights corresponding
99105 to the a-priori selected N control points, associated to the basis
100106 functions and c and Q terms that describe the polynomial of order one
@@ -112,7 +118,7 @@ class RBF(Deformation):
112118 basis functions.
113119 :cvar dict extra: the additional parameters that may be passed to the
114120 kernel function.
115-
121+
116122 :Example:
117123
118124 >>> from pygem import RBF
@@ -125,12 +131,61 @@ class RBF(Deformation):
125131 >>> mesh = np.array([x.ravel(), y.ravel(), z.ravel()])
126132 >>> deformed_mesh = rbf(mesh)
127133 """
134+
135+ # Precision mapping
136+ DTYPE_MAP = {
137+ 'fp16' : np .float16 ,
138+ 'float16' : np .float16 ,
139+ 'fp32' : np .float32 ,
140+ 'float32' : np .float32 ,
141+ 'fp64' : np .float64 ,
142+ 'float64' : np .float64 ,
143+ 'fp96' : np .float96 if hasattr (np , 'float96' ) else np .float64 ,
144+ 'float96' : np .float96 if hasattr (np , 'float96' ) else np .float64 ,
145+ 'fp128' : np .float128 if hasattr (np , 'float128' ) else np .float64 ,
146+ 'float128' : np .float128 if hasattr (np , 'float128' ) else np .float64 ,
147+ }
148+
128149 def __init__ (self ,
129150 original_control_points = None ,
130151 deformed_control_points = None ,
131152 func = 'gaussian_spline' ,
132153 radius = 0.5 ,
133- extra_parameter = None ):
154+ extra_parameter = None ,
155+ dtype = 'fp64' ):
156+
157+ # Parse and set dtype with platform check
158+ if isinstance (dtype , str ):
159+ dtype_lower = dtype .lower ()
160+ if dtype_lower not in self .DTYPE_MAP :
161+ raise ValueError (
162+ f"Unsupported dtype '{ dtype } '. Supported values: "
163+ f"{ list (self .DTYPE_MAP .keys ())} "
164+ )
165+
166+ # Check for fp128 fallback
167+ if dtype_lower in ['fp128' , 'float128' ]:
168+ if not hasattr (np , 'float128' ):
169+ warnings .warn (
170+ "fp128/float128 is not supported on this platform. "
171+ "Automatically falling back to fp64. "
172+ "For true quad-precision, consider using Linux platform." ,
173+ RuntimeWarning
174+ )
175+
176+ # Check for fp96 fallback
177+ if dtype_lower in ['fp96' , 'float96' ]:
178+ if not hasattr (np , 'float96' ):
179+ warnings .warn (
180+ "fp96/float96 is not supported on this platform. "
181+ "Automatically falling back to fp64. "
182+ "For higher precision consider using 'fp128' (if available) " ,
183+ RuntimeWarning
184+ )
185+
186+ self ._dtype = self .DTYPE_MAP [dtype_lower ]
187+ else :
188+ self ._dtype = dtype
134189
135190 self .basis = func
136191 self .radius = radius
@@ -139,26 +194,25 @@ def __init__(self,
139194 self .original_control_points = np .array ([[0. , 0. , 0. ], [0. , 0. , 1. ],
140195 [0. , 1. , 0. ], [1. , 0. , 0. ],
141196 [0. , 1. , 1. ], [1. , 0. , 1. ],
142- [1. , 1. , 0. ], [1. , 1. ,
143- 1. ]] )
197+ [1. , 1. , 0. ], [1. , 1. , 1. ]],
198+ dtype = self . _dtype )
144199 else :
145- self .original_control_points = original_control_points
200+ self .original_control_points = np . asarray ( original_control_points , dtype = self . _dtype )
146201
147202 if deformed_control_points is None :
148203 self .deformed_control_points = np .array ([[0. , 0. , 0. ], [0. , 0. , 1. ],
149204 [0. , 1. , 0. ], [1. , 0. , 0. ],
150205 [0. , 1. , 1. ], [1. , 0. , 1. ],
151- [1. , 1. , 0. ], [1. , 1. ,
152- 1. ]] )
206+ [1. , 1. , 0. ], [1. , 1. , 1. ]],
207+ dtype = self . _dtype )
153208 else :
154- self .deformed_control_points = deformed_control_points
209+ self .deformed_control_points = np . asarray ( deformed_control_points , dtype = self . _dtype )
155210
156211 self .extra = extra_parameter if extra_parameter else dict ()
157212
158213 self .weights = self ._get_weights (self .original_control_points ,
159214 self .deformed_control_points )
160215
161-
162216 @property
163217 def n_control_points (self ):
164218 """
@@ -209,17 +263,26 @@ def _get_weights(self, X, Y):
209263 :rtype: numpy.ndarray
210264 """
211265 npts , dim = X .shape
212- H = np .zeros ((npts + 3 + 1 , npts + 3 + 1 ))
213- H [:npts , :npts ] = self .basis (cdist (X , X ), self .radius , ** self .extra )
214- H [npts , :npts ] = 1.0
215- H [:npts , npts ] = 1.0
266+ size = npts + 3 + 1
267+ H = np .zeros ((size , size ), dtype = self ._dtype )
268+
269+ # Compute distances and basis values using configured precision
270+ dists = cdist (X , X ).astype (self ._dtype )
271+ basis_block = self .basis (dists , self .radius , ** self .extra )
272+ basis_block = np .asarray (basis_block , dtype = self ._dtype )
273+
274+ H [:npts , :npts ] = basis_block
275+ H [npts , :npts ] = self ._dtype (1.0 )
276+ H [:npts , npts ] = self ._dtype (1.0 )
216277 H [:npts , - 3 :] = X
217278 H [- 3 :, :npts ] = X .T
218279
219- rhs = np .zeros ((npts + 3 + 1 , dim ))
280+ rhs = np .zeros ((size , dim ), dtype = self . _dtype )
220281 rhs [:npts , :] = Y
221- weights = np .linalg .solve (H , rhs )
222- return weights
282+
283+ solve_dtype = np .float64 if self ._dtype not in (np .float32 , np .float64 ) else self ._dtype
284+ weights = np .linalg .solve (H .astype (solve_dtype ), rhs .astype (solve_dtype )).astype (self ._dtype )
285+ return weights .astype (self ._dtype )
223286
224287 def read_parameters (self , filename = 'parameters_rbf.prm' ):
225288 """
@@ -242,20 +305,20 @@ def read_parameters(self, filename='parameters_rbf.prm'):
242305 config .read (filename )
243306
244307 rbf_settings = dict (config .items ('Radial Basis Functions' ))
245-
308+
246309 self .basis = rbf_settings .pop ('basis function' )
247310 self .radius = float (rbf_settings .pop ('radius' ))
248311 self .extra = {k : eval (v ) for k , v in rbf_settings .items ()}
249312
250313 ctrl_points = config .get ('Control points' , 'original control points' )
251314 lines = ctrl_points .split ('\n ' )
252315 self .original_control_points = np .array (
253- list (map (lambda x : x .split (), lines )), dtype = float )
316+ list (map (lambda x : x .split (), lines )), dtype = self . _dtype )
254317
255318 mod_points = config .get ('Control points' , 'deformed control points' )
256319 lines = mod_points .split ('\n ' )
257320 self .deformed_control_points = np .array (
258- list (map (lambda x : x .split (), lines )), dtype = float )
321+ list (map (lambda x : x .split (), lines )), dtype = self . _dtype )
259322
260323 if len (lines ) != self .n_control_points :
261324 raise TypeError ("The number of control points must be equal both in"
@@ -308,8 +371,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
308371 for i in range (0 , self .n_control_points ):
309372 output_string += offset * ' ' + str (
310373 self .original_control_points [i ][0 ]) + ' ' + str (
311- self .original_control_points [i ][1 ]) + ' ' + str (
312- self .original_control_points [i ][2 ]) + '\n '
374+ self .original_control_points [i ][1 ]) + ' ' + str (
375+ self .original_control_points [i ][2 ]) + '\n '
313376 offset = 25
314377
315378 output_string += '\n # deformed control points collects the coordinates'
@@ -321,8 +384,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
321384 for i in range (0 , self .n_control_points ):
322385 output_string += offset * ' ' + str (
323386 self .deformed_control_points [i ][0 ]) + ' ' + str (
324- self .deformed_control_points [i ][1 ]) + ' ' + str (
325- self .deformed_control_points [i ][2 ]) + '\n '
387+ self .deformed_control_points [i ][1 ]) + ' ' + str (
388+ self .deformed_control_points [i ][2 ]) + '\n '
326389 offset = 25
327390
328391 with open (filename , 'w' ) as f :
@@ -393,110 +456,18 @@ def __call__(self, src_pts):
393456 This method performs the deformation of the mesh points. After the
394457 execution it sets `self.modified_mesh_points`.
395458 """
396- self .compute_weights ()
397-
398- H = np .zeros ((src_pts .shape [0 ], self .n_control_points + 3 + 1 ))
399- H [:, :self .n_control_points ] = self .basis (
400- cdist (src_pts , self .original_control_points ),
401- self .radius ,
402- ** self .extra )
403- H [:, self .n_control_points ] = 1.0
404- H [:, - 3 :] = src_pts
405- return np .asarray (np .dot (H , self .weights ))
406-
407- class RBFSinglePrecision (RBF ):
408- """
409- Memory-optimized RBF that stores and computes large matrices in single
410- precision (float32). Other behavior matches `RBF`.
411-
412- Use this class when memory is constrained; results remain in float32.
413- """
414-
415- def __init__ (self ,
416- original_control_points = None ,
417- deformed_control_points = None ,
418- func = 'gaussian_spline' ,
419- radius = 0.5 ,
420- extra_parameter = None ,
421- dtype = np .float32 ):
422-
423- # store desired dtype for heavy arrays
424- self ._dtype = dtype
425- # set basis and radius using parent property setters
426- self .basis = func
427- self .radius = radius
428-
429- # initialize control points in single precision
430- if original_control_points is None :
431- self .original_control_points = np .array (
432- [[0. , 0. , 0. ], [0. , 0. , 1. ], [0. , 1. , 0. ], [1. , 0. , 0. ],
433- [0. , 1. , 1. ], [1. , 0. , 1. ], [1. , 1. , 0. ], [1. , 1. , 1. ]],
434- dtype = self ._dtype )
435- else :
436- self .original_control_points = np .asarray (original_control_points ,
437- dtype = self ._dtype )
438-
439- if deformed_control_points is None :
440- self .deformed_control_points = np .array (
441- [[0. , 0. , 0. ], [0. , 0. , 1. ], [0. , 1. , 0. ], [1. , 0. , 0. ],
442- [0. , 1. , 1. ], [1. , 0. , 1. ], [1. , 1. , 0. ], [1. , 1. , 1. ]],
443- dtype = self ._dtype )
444- else :
445- self .deformed_control_points = np .asarray (deformed_control_points ,
446- dtype = self ._dtype )
447-
448- # extra parameters (small), keep as provided
449- self .extra = extra_parameter if extra_parameter else dict ()
450-
451- # compute weights in single precision
452- self .weights = self ._get_weights (self .original_control_points ,
453- self .deformed_control_points )
454-
455- def _get_weights (self , X , Y ):
456- """
457- Single-precision version of weight computation. Large matrices (H, rhs,
458- basis evaluations) use float32 to reduce memory usage.
459- """
460- npts , dim = X .shape
461- size = npts + 3 + 1
462- H = np .zeros ((size , size ), dtype = self ._dtype )
463-
464- # compute pairwise distances then cast to single precision
465- dists = cdist (X .astype (np .float64 ), X .astype (np .float64 )).astype (self ._dtype )
466- basis_block = self .basis (dists , self .radius , ** self .extra )
467- # ensure basis_block is single precision
468- basis_block = np .asarray (basis_block , dtype = self ._dtype )
469- H [:npts , :npts ] = basis_block
470-
471- H [npts , :npts ] = self ._dtype (1.0 )
472- H [:npts , npts ] = self ._dtype (1.0 )
473- H [:npts , - 3 :] = X
474- H [- 3 :, :npts ] = X .T
475-
476- rhs = np .zeros ((size , dim ), dtype = self ._dtype )
477- rhs [:npts , :] = Y
478-
479- # solve in single precision
480- weights = np .linalg .solve (H .astype (self ._dtype ), rhs .astype (self ._dtype ))
481- return weights .astype (self ._dtype )
482-
483- def __call__ (self , src_pts ):
484- """
485- Deform `src_pts`. Heavy temporary arrays are single precision.
486- """
487- # ensure src_pts in single precision for computations
488459 src = np .asarray (src_pts , dtype = self ._dtype )
489- # recompute weights to keep consistency with parent API
490- self .weights = self ._get_weights (self .original_control_points ,
491- self .deformed_control_points )
460+ self .compute_weights ()
492461
493462 H = np .zeros ((src .shape [0 ], self .n_control_points + 3 + 1 ),
494463 dtype = self ._dtype )
495464
496- dists = cdist (src . astype ( np . float64 ) , self .original_control_points . astype ( np . float64 ) ).astype (self ._dtype )
465+ dists = cdist (src , self .original_control_points ).astype (self ._dtype )
497466 basis_block = self .basis (dists , self .radius , ** self .extra )
467+
498468 H [:, :self .n_control_points ] = np .asarray (basis_block , dtype = self ._dtype )
499469 H [:, self .n_control_points ] = self ._dtype (1.0 )
500470 H [:, - 3 :] = src
471+
501472 result = np .dot (H , self .weights )
502473 return np .asarray (result , dtype = self ._dtype )
0 commit comments