@@ -403,3 +403,100 @@ def __call__(self, src_pts):
403403 H [:, self .n_control_points ] = 1.0
404404 H [:, - 3 :] = src_pts
405405 return np .asarray (np .dot (H , self .weights ))
406+
class RBFSinglePrecision(RBF):
    """
    Memory-optimized variant of `RBF` that keeps its large arrays (control
    points, interpolation matrices, weights and deformed output) in a reduced
    floating-point precision, ``numpy.float32`` by default.

    Pairwise distances are evaluated from float64 copies of the points and
    cast to the working precision right away; everything downstream (basis
    evaluation, linear solve, final product) runs in the working precision.

    Control points are treated as 3D: the interpolation system reserves one
    constant column and three linear columns, as in the parent class.

    :param original_control_points: undeformed control points, one point per
        row. Defaults to the 8 vertices of the unit cube.
    :param deformed_control_points: deformed control points, one point per
        row. Defaults to the 8 vertices of the unit cube.
    :param func: basis function, forwarded to the parent `basis` setter.
    :param radius: scaling parameter forwarded to the basis function.
    :param dict extra_parameter: extra keyword arguments forwarded to the
        basis function on every evaluation.
    :param dtype: working precision. Anything accepted by ``numpy.dtype``
        (scalar type, dtype instance, or string such as ``'float32'``).
    """

    def __init__(self,
                 original_control_points=None,
                 deformed_control_points=None,
                 func='gaussian_spline',
                 radius=0.5,
                 extra_parameter=None,
                 dtype=np.float32):

        # Normalise to the NumPy scalar type so that strings and dtype
        # instances are accepted too; the default stays `np.float32`.
        self._dtype = np.dtype(dtype).type

        # Basis and radius go through the parent's property setters.
        self.basis = func
        self.radius = radius

        if original_control_points is None:
            self.original_control_points = self._unit_cube_points()
        else:
            self.original_control_points = np.asarray(
                original_control_points, dtype=self._dtype)

        if deformed_control_points is None:
            self.deformed_control_points = self._unit_cube_points()
        else:
            self.deformed_control_points = np.asarray(
                deformed_control_points, dtype=self._dtype)

        # Extra basis parameters are small; keep them exactly as provided.
        self.extra = extra_parameter if extra_parameter else dict()

        self.weights = self._get_weights(self.original_control_points,
                                         self.deformed_control_points)

    def _unit_cube_points(self):
        """
        Return a fresh array with the 8 unit-cube vertices in the working
        precision (the default control points).
        """
        return np.array(
            [[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
             [0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
            dtype=self._dtype)

    def _get_weights(self, X, Y):
        """
        Solve the interpolation system for the RBF weights in the working
        precision.

        :param numpy.ndarray X: original control points, shape (npts, 3).
        :param numpy.ndarray Y: deformed control points, shape (npts, dim).
        :return: the weights, shape (npts + 4, dim), in the working precision.
        :rtype: numpy.ndarray
        """
        npts, dim = X.shape
        size = npts + 3 + 1
        H = np.zeros((size, size), dtype=self._dtype)

        # Convert once to float64 for the distance evaluation, then drop to
        # the working precision before the basis is evaluated.
        X64 = np.asarray(X, dtype=np.float64)
        dists = cdist(X64, X64).astype(self._dtype)
        # Assigning into the typed slice casts the basis values in place,
        # without a separate converted temporary.
        H[:npts, :npts] = self.basis(dists, self.radius, **self.extra)
        del dists  # release the npts x npts temporary before the solve

        H[npts, :npts] = 1.0
        H[:npts, npts] = 1.0
        H[:npts, -3:] = X
        H[-3:, :npts] = X.T

        rhs = np.zeros((size, dim), dtype=self._dtype)
        rhs[:npts, :] = Y

        # H and rhs already have the working dtype: pass them directly
        # instead of copying them with `astype`.
        weights = np.linalg.solve(H, rhs)
        return weights.astype(self._dtype, copy=False)

    def __call__(self, src_pts):
        """
        Deform `src_pts` using the current control points.

        The weights are recomputed on every call, so changes made to the
        control points after construction are taken into account.

        :param src_pts: points to deform, one point per row (3 columns).
        :return: the deformed points in the working precision.
        :rtype: numpy.ndarray
        """
        src = np.asarray(src_pts, dtype=self._dtype)

        self.weights = self._get_weights(self.original_control_points,
                                         self.deformed_control_points)

        n_ctrl = self.n_control_points
        H = np.zeros((src.shape[0], n_ctrl + 3 + 1), dtype=self._dtype)

        # Distances are taken from the precision-rounded points (`src`), not
        # from the raw input, then cast to the working precision.
        dists = cdist(
            np.asarray(src, dtype=np.float64),
            np.asarray(self.original_control_points, dtype=np.float64),
        ).astype(self._dtype)
        H[:, :n_ctrl] = self.basis(dists, self.radius, **self.extra)
        del dists  # release the temporary before the matrix product

        H[:, n_ctrl] = 1.0
        H[:, -3:] = src

        return np.asarray(np.dot(H, self.weights), dtype=self._dtype)
0 commit comments