@@ -42,7 +42,10 @@ class CFFD(FFD):
4242 :cvar numpy.ndarray fun_mask: a boolean tensor that tells to the class
4343 on which axis which constraint depends on. The tensor has shape (n_cons,3), where the last dimension indicates dependency on
4444 on x,y,z respectively. Default is all true. It used only in the triaffine mode.
45-
45+ :cvar numpy.ndarray weight_matrix: a symmetric positive definite weigth matrix.
46+ It must be of row and column size the number of trues in the mask.
47+ It weights the movemement of the control points which have a true flag in the ffd_mask.
48+ Default is identity.
4649
4750 :Example:
4851
@@ -64,7 +67,7 @@ def __init__(self,
6467 fun ,
6568 n_control_points = None ,
6669 ffd_mask = None ,
67- fun_mask = None ):
70+ fun_mask = None , weight_matrix = None ):
6871 super ().__init__ (n_control_points )
6972
7073 if ffd_mask is None :
@@ -79,22 +82,23 @@ def __init__(self,
7982 self .fun_mask = np .full ((self .num_cons , 3 ), True , dtype = bool )
8083 else :
8184 self .fun_mask = fun_mask
82- def adjust_control_points (self ,src_pts ):
85+
86+ if weight_matrix is None :
87+ self .weight_matrix = np .eye (np .sum (self .ffd_mask ))
88+ def _adjust_control_points_inner (self ,src_pts ,hyper_param ):
8389 '''
84- Adjust the FFD control points such that fun(ffd(src_pts))=fixval
90+ Adjust the FFD control points such that fun(ffd(src_pts))=fixval given a hyperparameter.
8591
8692 :param np.ndarray src_pts: the points whose deformation we want to be
8793 constrained.
8894 :rtype: None.
8995 '''
90- vweight = self .fun_mask .copy ().astype (float )
91- vweight = vweight / np .sum (vweight ,axis = 1 )
9296 mask_bak = self .ffd_mask .copy ()
9397 diffvolume = self .fixval - self .fun (self .ffd (src_pts ))
9498 for i in range (3 ):
9599 self .ffd_mask = np .full ((* self .n_control_points , 3 ), False , dtype = bool )
96100 self .ffd_mask [:, :, :, i ] = mask_bak [:, :, :, i ].copy ()
97- self .fixval = self .fun (self .ffd (src_pts )) + vweight [:,i ] * (
101+ self .fixval = self .fun (self .ffd (src_pts )) + hyper_param [:,i ] * (
98102 diffvolume
99103 )
100104 saved_parameters = self ._save_parameters ()
@@ -116,7 +120,17 @@ def adjust_control_points(self,src_pts):
116120 self .ffd_mask = mask_bak .copy ()
117121
118122
119-
123+ def adjust_control_points (self ,src_pts ):
124+ '''
125+ Adjust the FFD control points such that fun(ffd(src_pts))=fixval
126+
127+ :param np.ndarray src_pts: the points whose deformation we want to be
128+ constrained.
129+ :rtype: None.
130+ '''
131+ hyper_param = self .fun_mask .copy ().astype (float )
132+ hyper_param = hyper_param / np .sum (hyper_param ,axis = 1 )
133+ self ._adjust_control_points_inner (src_pts ,hyper_param )
120134 def ffd (self , src_pts ):
121135 '''
122136 Performs Classic Free Form Deformation.
0 commit comments