Skip to content

Commit 3cd99bf

Browse files
Icbearsndem0
authored andcommitted
Add RBFSinglePrecision for a single precision option
Introduced RBFSinglePrecision class for memory optimization.
1 parent 82faf7a commit 3cd99bf

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

pygem/rbf.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,100 @@ def __call__(self, src_pts):
403403
H[:, self.n_control_points] = 1.0
404404
H[:, -3:] = src_pts
405405
return np.asarray(np.dot(H, self.weights))
406+
407+
class RBFSinglePrecision(RBF):
408+
"""
409+
Memory-optimized RBF that stores and computes large matrices in single
410+
precision (float32). Other behavior matches `RBF`.
411+
412+
Use this class when memory is constrained; results remain in float32.
413+
"""
414+
415+
def __init__(self,
416+
original_control_points=None,
417+
deformed_control_points=None,
418+
func='gaussian_spline',
419+
radius=0.5,
420+
extra_parameter=None,
421+
dtype=np.float32):
422+
423+
# store desired dtype for heavy arrays
424+
self._dtype = dtype
425+
# set basis and radius using parent property setters
426+
self.basis = func
427+
self.radius = radius
428+
429+
# initialize control points in single precision
430+
if original_control_points is None:
431+
self.original_control_points = np.array(
432+
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
433+
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
434+
dtype=self._dtype)
435+
else:
436+
self.original_control_points = np.asarray(original_control_points,
437+
dtype=self._dtype)
438+
439+
if deformed_control_points is None:
440+
self.deformed_control_points = np.array(
441+
[[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
442+
[0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
443+
dtype=self._dtype)
444+
else:
445+
self.deformed_control_points = np.asarray(deformed_control_points,
446+
dtype=self._dtype)
447+
448+
# extra parameters (small), keep as provided
449+
self.extra = extra_parameter if extra_parameter else dict()
450+
451+
# compute weights in single precision
452+
self.weights = self._get_weights(self.original_control_points,
453+
self.deformed_control_points)
454+
455+
def _get_weights(self, X, Y):
456+
"""
457+
Single-precision version of weight computation. Large matrices (H, rhs,
458+
basis evaluations) use float32 to reduce memory usage.
459+
"""
460+
npts, dim = X.shape
461+
size = npts + 3 + 1
462+
H = np.zeros((size, size), dtype=self._dtype)
463+
464+
# compute pairwise distances then cast to single precision
465+
dists = cdist(X.astype(np.float64), X.astype(np.float64)).astype(self._dtype)
466+
basis_block = self.basis(dists, self.radius, **self.extra)
467+
# ensure basis_block is single precision
468+
basis_block = np.asarray(basis_block, dtype=self._dtype)
469+
H[:npts, :npts] = basis_block
470+
471+
H[npts, :npts] = self._dtype(1.0)
472+
H[:npts, npts] = self._dtype(1.0)
473+
H[:npts, -3:] = X
474+
H[-3:, :npts] = X.T
475+
476+
rhs = np.zeros((size, dim), dtype=self._dtype)
477+
rhs[:npts, :] = Y
478+
479+
# solve in single precision
480+
weights = np.linalg.solve(H.astype(self._dtype), rhs.astype(self._dtype))
481+
return weights.astype(self._dtype)
482+
483+
def __call__(self, src_pts):
484+
"""
485+
Deform `src_pts`. Heavy temporary arrays are single precision.
486+
"""
487+
# ensure src_pts in single precision for computations
488+
src = np.asarray(src_pts, dtype=self._dtype)
489+
# recompute weights to keep consistency with parent API
490+
self.weights = self._get_weights(self.original_control_points,
491+
self.deformed_control_points)
492+
493+
H = np.zeros((src.shape[0], self.n_control_points + 3 + 1),
494+
dtype=self._dtype)
495+
496+
dists = cdist(src.astype(np.float64), self.original_control_points.astype(np.float64)).astype(self._dtype)
497+
basis_block = self.basis(dists, self.radius, **self.extra)
498+
H[:, :self.n_control_points] = np.asarray(basis_block, dtype=self._dtype)
499+
H[:, self.n_control_points] = self._dtype(1.0)
500+
H[:, -3:] = src
501+
result = np.dot(H, self.weights)
502+
return np.asarray(result, dtype=self._dtype)

0 commit comments

Comments
 (0)