Skip to content

Commit 07495dc

Browse files
authored
Add dtype parameter for precision control in RBF
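Added a dtype parameter to RBF for precision control and updated the related methods to honour it. Supported precisions are fp16, fp32, fp64 (the default), fp96, and fp128; an unrecognised dtype string raises a ValueError. When fp96 or fp128 is not available on the platform, a RuntimeWarning is emitted and fp64 is used instead. The separate RBFSinglePrecision class is removed, since single precision can now be selected with dtype='fp32'.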
1 parent 933194a commit 07495dc

File tree

1 file changed

+92
-121
lines changed

1 file changed

+92
-121
lines changed

pygem/rbf.py

Lines changed: 92 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070

7171
import matplotlib.pyplot as plt
7272

73+
import warnings
74+
7375

7476
class RBF(Deformation):
7577
"""
@@ -93,8 +95,12 @@ class RBF(Deformation):
9395
basis functions. For details see the class
9496
:class:`RBF`. The default value is 0.5.
9597
:param dict extra_parameter: the additional parameters that may be passed to
96-
the kernel function. Default is None.
97-
98+
the kernel function. Default is None.
99+
:param str dtype: Precision specification. Supported values:
100+
'fp16'/'float16', 'fp32'/'float32', 'fp64'/'float64' (default),
101+
'fp96'/'float96', 'fp128'/'float128' (if unavailable on the platform, a RuntimeWarning is issued and fp64 is used instead).
102+
Default is 'fp64'.
103+
98104
:cvar numpy.ndarray weights: the matrix formed by the weights corresponding
99105
to the a-priori selected N control points, associated to the basis
100106
functions and c and Q terms that describe the polynomial of order one
@@ -112,7 +118,7 @@ class RBF(Deformation):
112118
basis functions.
113119
:cvar dict extra: the additional parameters that may be passed to the
114120
kernel function.
115-
121+
116122
:Example:
117123
118124
>>> from pygem import RBF
@@ -125,12 +131,61 @@ class RBF(Deformation):
125131
>>> mesh = np.array([x.ravel(), y.ravel(), z.ravel()])
126132
>>> deformed_mesh = rbf(mesh)
127133
"""
134+
135+
# Precision mapping
136+
DTYPE_MAP = {
137+
'fp16': np.float16,
138+
'float16': np.float16,
139+
'fp32': np.float32,
140+
'float32': np.float32,
141+
'fp64': np.float64,
142+
'float64': np.float64,
143+
'fp96': np.float96 if hasattr(np, 'float96') else np.float64,
144+
'float96': np.float96 if hasattr(np, 'float96') else np.float64,
145+
'fp128': np.float128 if hasattr(np, 'float128') else np.float64,
146+
'float128': np.float128 if hasattr(np, 'float128') else np.float64,
147+
}
148+
128149
def __init__(self,
129150
original_control_points=None,
130151
deformed_control_points=None,
131152
func='gaussian_spline',
132153
radius=0.5,
133-
extra_parameter=None):
154+
extra_parameter=None,
155+
dtype='fp64'):
156+
157+
# Parse and set dtype with platform check
158+
if isinstance(dtype, str):
159+
dtype_lower = dtype.lower()
160+
if dtype_lower not in self.DTYPE_MAP:
161+
raise ValueError(
162+
f"Unsupported dtype '{dtype}'. Supported values: "
163+
f"{list(self.DTYPE_MAP.keys())}"
164+
)
165+
166+
# Check for fp128 fallback
167+
if dtype_lower in ['fp128', 'float128']:
168+
if not hasattr(np, 'float128'):
169+
warnings.warn(
170+
"fp128/float128 is not supported on this platform. "
171+
"Automatically falling back to fp64. "
172+
"For true quad-precision, consider using Linux platform.",
173+
RuntimeWarning
174+
)
175+
176+
# Check for fp96 fallback
177+
if dtype_lower in ['fp96', 'float96']:
178+
if not hasattr(np, 'float96'):
179+
warnings.warn(
180+
"fp96/float96 is not supported on this platform. "
181+
"Automatically falling back to fp64. "
182+
"For higher precision consider using 'fp128' (if available) ",
183+
RuntimeWarning
184+
)
185+
186+
self._dtype = self.DTYPE_MAP[dtype_lower]
187+
else:
188+
self._dtype = dtype
134189

135190
self.basis = func
136191
self.radius = radius
@@ -139,26 +194,25 @@ def __init__(self,
139194
self.original_control_points = np.array([[0., 0., 0.], [0., 0., 1.],
140195
[0., 1., 0.], [1., 0., 0.],
141196
[0., 1., 1.], [1., 0., 1.],
142-
[1., 1., 0.], [1., 1.,
143-
1.]])
197+
[1., 1., 0.], [1., 1., 1.]],
198+
dtype=self._dtype)
144199
else:
145-
self.original_control_points = original_control_points
200+
self.original_control_points = np.asarray(original_control_points, dtype=self._dtype)
146201

147202
if deformed_control_points is None:
148203
self.deformed_control_points = np.array([[0., 0., 0.], [0., 0., 1.],
149204
[0., 1., 0.], [1., 0., 0.],
150205
[0., 1., 1.], [1., 0., 1.],
151-
[1., 1., 0.], [1., 1.,
152-
1.]])
206+
[1., 1., 0.], [1., 1., 1.]],
207+
dtype=self._dtype)
153208
else:
154-
self.deformed_control_points = deformed_control_points
209+
self.deformed_control_points = np.asarray(deformed_control_points, dtype=self._dtype)
155210

156211
self.extra = extra_parameter if extra_parameter else dict()
157212

158213
self.weights = self._get_weights(self.original_control_points,
159214
self.deformed_control_points)
160215

161-
162216
@property
163217
def n_control_points(self):
164218
"""
@@ -209,17 +263,26 @@ def _get_weights(self, X, Y):
209263
:rtype: numpy.ndarray
210264
"""
211265
npts, dim = X.shape
212-
H = np.zeros((npts + 3 + 1, npts + 3 + 1))
213-
H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra)
214-
H[npts, :npts] = 1.0
215-
H[:npts, npts] = 1.0
266+
size = npts + 3 + 1
267+
H = np.zeros((size, size), dtype=self._dtype)
268+
269+
# Compute pairwise distances, cast them to the configured precision, then evaluate the basis
270+
dists = cdist(X, X).astype(self._dtype)
271+
basis_block = self.basis(dists, self.radius, **self.extra)
272+
basis_block = np.asarray(basis_block, dtype=self._dtype)
273+
274+
H[:npts, :npts] = basis_block
275+
H[npts, :npts] = self._dtype(1.0)
276+
H[:npts, npts] = self._dtype(1.0)
216277
H[:npts, -3:] = X
217278
H[-3:, :npts] = X.T
218279

219-
rhs = np.zeros((npts + 3 + 1, dim))
280+
rhs = np.zeros((size, dim), dtype=self._dtype)
220281
rhs[:npts, :] = Y
221-
weights = np.linalg.solve(H, rhs)
222-
return weights
282+
283+
solve_dtype = np.float64 if self._dtype not in (np.float32, np.float64) else self._dtype
284+
weights = np.linalg.solve(H.astype(solve_dtype), rhs.astype(solve_dtype)).astype(self._dtype)
285+
return weights.astype(self._dtype)
223286

224287
def read_parameters(self, filename='parameters_rbf.prm'):
225288
"""
@@ -242,20 +305,20 @@ def read_parameters(self, filename='parameters_rbf.prm'):
242305
config.read(filename)
243306

244307
rbf_settings = dict(config.items('Radial Basis Functions'))
245-
308+
246309
self.basis = rbf_settings.pop('basis function')
247310
self.radius = float(rbf_settings.pop('radius'))
248311
self.extra = {k: eval(v) for k, v in rbf_settings.items()}
249312

250313
ctrl_points = config.get('Control points', 'original control points')
251314
lines = ctrl_points.split('\n')
252315
self.original_control_points = np.array(
253-
list(map(lambda x: x.split(), lines)), dtype=float)
316+
list(map(lambda x: x.split(), lines)), dtype=self._dtype)
254317

255318
mod_points = config.get('Control points', 'deformed control points')
256319
lines = mod_points.split('\n')
257320
self.deformed_control_points = np.array(
258-
list(map(lambda x: x.split(), lines)), dtype=float)
321+
list(map(lambda x: x.split(), lines)), dtype=self._dtype)
259322

260323
if len(lines) != self.n_control_points:
261324
raise TypeError("The number of control points must be equal both in"
@@ -308,8 +371,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
308371
for i in range(0, self.n_control_points):
309372
output_string += offset * ' ' + str(
310373
self.original_control_points[i][0]) + ' ' + str(
311-
self.original_control_points[i][1]) + ' ' + str(
312-
self.original_control_points[i][2]) + '\n'
374+
self.original_control_points[i][1]) + ' ' + str(
375+
self.original_control_points[i][2]) + '\n'
313376
offset = 25
314377

315378
output_string += '\n# deformed control points collects the coordinates'
@@ -321,8 +384,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
321384
for i in range(0, self.n_control_points):
322385
output_string += offset * ' ' + str(
323386
self.deformed_control_points[i][0]) + ' ' + str(
324-
self.deformed_control_points[i][1]) + ' ' + str(
325-
self.deformed_control_points[i][2]) + '\n'
387+
self.deformed_control_points[i][1]) + ' ' + str(
388+
self.deformed_control_points[i][2]) + '\n'
326389
offset = 25
327390

328391
with open(filename, 'w') as f:
@@ -393,110 +456,18 @@ def __call__(self, src_pts):
393456
This method performs the deformation of the mesh points. After the
394457
execution it sets `self.modified_mesh_points`.
395458
"""
396-
self.compute_weights()
397-
398-
H = np.zeros((src_pts.shape[0], self.n_control_points + 3 + 1))
399-
H[:, :self.n_control_points] = self.basis(
400-
cdist(src_pts, self.original_control_points),
401-
self.radius,
402-
**self.extra)
403-
H[:, self.n_control_points] = 1.0
404-
H[:, -3:] = src_pts
405-
return np.asarray(np.dot(H, self.weights))
406-
407-
class RBFSinglePrecision(RBF):
408-
"""
409-
Memory-optimized RBF that stores and computes large matrices in single
410-
precision (float32). Other behavior matches `RBF`.
411-
412-
Use this class when memory is constrained; results remain in float32.
413-
"""
414-
415-
def __init__(self,
416-
original_control_points=None,
417-
deformed_control_points=None,
418-
func='gaussian_spline',
419-
radius=0.5,
420-
extra_parameter=None,
421-
dtype=np.float32):
422-
423-
# store desired dtype for heavy arrays
424-
self._dtype = dtype
425-
# set basis and radius using parent property setters
426-
self.basis = func
427-
self.radius = radius
428-
429-
# initialize control points in single precision
430-
if original_control_points is None:
431-
self.original_control_points = np.array(
432-
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
433-
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
434-
dtype=self._dtype)
435-
else:
436-
self.original_control_points = np.asarray(original_control_points,
437-
dtype=self._dtype)
438-
439-
if deformed_control_points is None:
440-
self.deformed_control_points = np.array(
441-
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
442-
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
443-
dtype=self._dtype)
444-
else:
445-
self.deformed_control_points = np.asarray(deformed_control_points,
446-
dtype=self._dtype)
447-
448-
# extra parameters (small), keep as provided
449-
self.extra = extra_parameter if extra_parameter else dict()
450-
451-
# compute weights in single precision
452-
self.weights = self._get_weights(self.original_control_points,
453-
self.deformed_control_points)
454-
455-
def _get_weights(self, X, Y):
456-
"""
457-
Single-precision version of weight computation. Large matrices (H, rhs,
458-
basis evaluations) use float32 to reduce memory usage.
459-
"""
460-
npts, dim = X.shape
461-
size = npts + 3 + 1
462-
H = np.zeros((size, size), dtype=self._dtype)
463-
464-
# compute pairwise distances then cast to single precision
465-
dists = cdist(X.astype(np.float64), X.astype(np.float64)).astype(self._dtype)
466-
basis_block = self.basis(dists, self.radius, **self.extra)
467-
# ensure basis_block is single precision
468-
basis_block = np.asarray(basis_block, dtype=self._dtype)
469-
H[:npts, :npts] = basis_block
470-
471-
H[npts, :npts] = self._dtype(1.0)
472-
H[:npts, npts] = self._dtype(1.0)
473-
H[:npts, -3:] = X
474-
H[-3:, :npts] = X.T
475-
476-
rhs = np.zeros((size, dim), dtype=self._dtype)
477-
rhs[:npts, :] = Y
478-
479-
# solve in single precision
480-
weights = np.linalg.solve(H.astype(self._dtype), rhs.astype(self._dtype))
481-
return weights.astype(self._dtype)
482-
483-
def __call__(self, src_pts):
484-
"""
485-
Deform `src_pts`. Heavy temporary arrays are single precision.
486-
"""
487-
# ensure src_pts in single precision for computations
488459
src = np.asarray(src_pts, dtype=self._dtype)
489-
# recompute weights to keep consistency with parent API
490-
self.weights = self._get_weights(self.original_control_points,
491-
self.deformed_control_points)
460+
self.compute_weights()
492461

493462
H = np.zeros((src.shape[0], self.n_control_points + 3 + 1),
494463
dtype=self._dtype)
495464

496-
dists = cdist(src.astype(np.float64), self.original_control_points.astype(np.float64)).astype(self._dtype)
465+
dists = cdist(src, self.original_control_points).astype(self._dtype)
497466
basis_block = self.basis(dists, self.radius, **self.extra)
467+
498468
H[:, :self.n_control_points] = np.asarray(basis_block, dtype=self._dtype)
499469
H[:, self.n_control_points] = self._dtype(1.0)
500470
H[:, -3:] = src
471+
501472
result = np.dot(H, self.weights)
502473
return np.asarray(result, dtype=self._dtype)

0 commit comments

Comments
 (0)