@@ -60,20 +60,37 @@ class CFFD(FFD):
6060 >>> new_mesh_points = cffd(original_mesh_points)
6161 >>> assert np.isclose(np.linalg.norm(fun(new_mesh_points)-b),np.array([0.]))
6262 """
63- def __init__ (self , n_control_points = None ):
63+ def __init__ (self , n_control_points = None , fun = None , fixval = None , M = None , mask = None ):
6464 super ().__init__ (n_control_points )
65- self .fun = None
66- self .fixval = None
67- self .indices = None
68- self .M = None
65+
66+ if mask == None :
67+ self .mask = np .full ((* self .n_control_points ,3 ), True , dtype = bool )
68+ else :
69+ self .mask = mask
70+
71+ if fixval == None :
72+ self .fixval = np .array ([1. ])
73+ else :
74+ self .fixval = fixval
75+
76+ if fun == None :
77+ self .fun = lambda x : self .fixval
78+
79+ else :
80+ self .fun = fun
81+
82+ if M == None :
83+ self .M = np .eye (np .sum (self .mask .astype (int )))
6984
7085 def __call__ (self , src_pts ):
7186 saved_parameters = self ._save_parameters ()
72- A , b = self ._compute_linear_map (src_pts , saved_parameters .copy ())
73- d = A @ saved_parameters [self .indices ] + b
74- deltax = np .linalg .inv (self .M ) @ A .T @ np .linalg .inv (
75- (A @ np .linalg .inv (self .M ) @ A .T )) @ (self .fixval - d )
76- saved_parameters [self .indices ] = saved_parameters [self .indices ] + deltax
87+ indices = np .arange (np .prod (self .n_control_points )* 3 )[self .mask .reshape (- 1 )]
88+ A , b = self ._compute_linear_map (src_pts , saved_parameters .copy (),indices )
89+ d = A @ saved_parameters [indices ] + b
90+ invM = np .linalg .inv (self .M )
91+ deltax = invM @ A .T @ np .linalg .inv (
92+ (A @ invM @ A .T )) @ (self .fixval - d )
93+ saved_parameters [indices ] = saved_parameters [indices ] + deltax
7794 self ._load_parameters (saved_parameters )
7895 return self .ffd (src_pts )
7996
@@ -112,10 +129,14 @@ def _load_parameters(self, tmp):
112129 self .array_mu_y = tmp [:, :, :, 1 ]
113130 self .array_mu_z = tmp [:, :, :, 2 ]
114131
132+ def read_parameters (self ,filename = 'parameters.prm' ):
133+ super ().read_parameters (filename )
134+ self .mask = np .full ((* self .n_control_points ,3 ), True , dtype = bool )
135+ self .M = np .eye (np .sum (self .mask .astype (int )))
115136
116137# I see that a similar function already exists in pygem.utils, but it does not work for inputs and outputs of different dimensions
117138
118- def _compute_linear_map (self , src_pts , saved_parameters ):
139+ def _compute_linear_map (self , src_pts , saved_parameters , indices ):
119140 '''
120141 Computes the coefficient and the intercept of the linear map from the control points to the output.
121142
@@ -124,7 +145,7 @@ def _compute_linear_map(self, src_pts, saved_parameters):
124145 :return: a tuple containing the coefficient and the intercept.
125146 :rtype: tuple(np.ndarray,np.ndarray)
126147 '''
127- n_indices = len (self . indices )
148+ n_indices = len (indices )
128149 inputs = np .zeros ([n_indices + 1 , n_indices + 1 ])
129150 outputs = np .zeros ([n_indices + 1 , self .fixval .shape [0 ]])
130151 np .random .seed (0 )
@@ -134,7 +155,7 @@ def _compute_linear_map(self, src_pts, saved_parameters):
134155 tmp = tmp .reshape (1 , - 1 )
135156 inputs [i ] = np .hstack ([tmp , np .ones (
136157 (tmp .shape [0 ], 1 ))]) #dependent variable
137- saved_parameters [self . indices ] = tmp
158+ saved_parameters [indices ] = tmp
138159 self ._load_parameters (
139160 saved_parameters
140161 ) #loading the depent variable as a control point
@@ -146,3 +167,22 @@ def _compute_linear_map(self, src_pts, saved_parameters):
146167 A = sol [0 ].T [:, :- 1 ] #coefficient
147168 b = sol [0 ].T [:, - 1 ] #intercept
148169 return A , b
170+
171+ np .random .seed (0 )
172+ cffd = CFFD ()
173+ cffd .read_parameters (
174+ "tests/test_datasets/parameters_test_ffd_sphere.prm" )
175+ original_mesh_points = np .load (
176+ "tests/test_datasets/meshpoints_sphere_orig.npy" )
177+ A = np .random .rand (3 , original_mesh_points .reshape (- 1 ).shape [0 ])
178+
179+ def fun (x ):
180+ x = x .reshape (- 1 )
181+ return A @ x
182+
183+ b = fun (original_mesh_points )
184+ cffd .fun = fun
185+ cffd .fixval = b
186+ new_mesh_points = cffd (original_mesh_points )
187+ assert np .isclose (np .linalg .norm (fun (new_mesh_points ) - b ),
188+ np .array ([0.0 ]))
0 commit comments