Skip to content

Commit 92cc306

Browse files
started working on hyperparameter optimization. Weight matrix is back.
1 parent bdf74cb commit 92cc306

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

pygem/cffd.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ class CFFD(FFD):
4242
:cvar numpy.ndarray fun_mask: a boolean tensor that tells to the class
4343
on which axis which constraint depends on. The tensor has shape (n_cons,3), where the last dimension indicates dependency on
4444
on x,y,z respectively. Default is all true. It used only in the triaffine mode.
45-
45+
:cvar numpy.ndarray weight_matrix: a symmetric positive definite weigth matrix.
46+
It must be of row and column size the number of trues in the mask.
47+
It weights the movemement of the control points which have a true flag in the ffd_mask.
48+
Default is identity.
4649
4750
:Example:
4851
@@ -64,7 +67,7 @@ def __init__(self,
6467
fun,
6568
n_control_points=None,
6669
ffd_mask=None,
67-
fun_mask=None):
70+
fun_mask=None, weight_matrix=None):
6871
super().__init__(n_control_points)
6972

7073
if ffd_mask is None:
@@ -79,22 +82,23 @@ def __init__(self,
7982
self.fun_mask = np.full((self.num_cons, 3), True, dtype=bool)
8083
else:
8184
self.fun_mask = fun_mask
82-
def adjust_control_points(self,src_pts):
85+
86+
if weight_matrix is None:
87+
self.weight_matrix = np.eye(np.sum(self.ffd_mask))
88+
def _adjust_control_points_inner(self,src_pts,hyper_param):
8389
'''
84-
Adjust the FFD control points such that fun(ffd(src_pts))=fixval
90+
Adjust the FFD control points such that fun(ffd(src_pts))=fixval given a hyperparameter.
8591
8692
:param np.ndarray src_pts: the points whose deformation we want to be
8793
constrained.
8894
:rtype: None.
8995
'''
90-
vweight=self.fun_mask.copy().astype(float)
91-
vweight=vweight/np.sum(vweight,axis=1)
9296
mask_bak = self.ffd_mask.copy()
9397
diffvolume = self.fixval - self.fun(self.ffd(src_pts))
9498
for i in range(3):
9599
self.ffd_mask = np.full((*self.n_control_points, 3), False, dtype=bool)
96100
self.ffd_mask[:, :, :, i] = mask_bak[:, :, :, i].copy()
97-
self.fixval = self.fun(self.ffd(src_pts)) + vweight[:,i] * (
101+
self.fixval = self.fun(self.ffd(src_pts)) + hyper_param[:,i] * (
98102
diffvolume
99103
)
100104
saved_parameters = self._save_parameters()
@@ -116,7 +120,17 @@ def adjust_control_points(self,src_pts):
116120
self.ffd_mask = mask_bak.copy()
117121

118122

119-
123+
def adjust_control_points(self,src_pts):
124+
'''
125+
Adjust the FFD control points such that fun(ffd(src_pts))=fixval
126+
127+
:param np.ndarray src_pts: the points whose deformation we want to be
128+
constrained.
129+
:rtype: None.
130+
'''
131+
hyper_param=self.fun_mask.copy().astype(float)
132+
hyper_param=hyper_param/np.sum(hyper_param,axis=1)
133+
self._adjust_control_points_inner(src_pts,hyper_param)
120134
def ffd(self, src_pts):
121135
'''
122136
Performs Classic Free Form Deformation.

0 commit comments

Comments
 (0)