Skip to content

Commit de60a57

Browse files
started refactoring and deletion of two useless lines
1 parent 1b0f8f2 commit de60a57

File tree

6 files changed

+152
-160
lines changed

6 files changed

+152
-160
lines changed

pygem/bffd.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,11 @@ class BFFD(CFFD):
5656
'''
5757
def __init__(self, n_control_points=None):
5858
super().__init__(n_control_points)
59+
5960
def linfun(x):
60-
return np.mean(x.reshape(-1,3),axis=0)
61+
return np.mean(x.reshape(-1, 3), axis=0)
6162

62-
self.linconstraint=linfun
63+
self.fun = linfun
6364

6465
def __call__(self, src_pts):
6566
return super().__call__(src_pts)
66-
67-
if __name__ == "__main__":
68-
from pygem import BFFD
69-
import numpy as np
70-
bffd = BFFD()
71-
bffd.read_parameters('tests/test_datasets/parameters_test_ffd_sphere.prm')
72-
original_mesh_points = np.load('tests/test_datasets/meshpoints_sphere_orig.npy')
73-
b=bffd.linconstraint(original_mesh_points)
74-
bffd.valconstraint=b
75-
bffd.indices=np.arange(np.prod(bffd.n_control_points)*3).tolist()
76-
bffd.M=np.eye(len(bffd.indices))
77-
new_mesh_points = bffd(original_mesh_points)
78-
assert np.isclose(np.linalg.norm(bffd.linconstraint(new_mesh_points)-b),np.array([0.]))

pygem/cffd.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class CFFD(FFD):
3434
y, normalized with the box length y.
3535
:cvar numpy.ndarray array_mu_z: collects the displacements (weights) along
3636
z, normalized with the box length z.
37-
:cvar callable linconstraint: it defines the F of the constraint F(x)=c.
38-
:cvar numpy.ndarray valconstraint: it defines the c of the constraint F(x)=c.
37+
:cvar callable fun: it defines the F of the constraint F(x)=c.
38+
:cvar numpy.ndarray fixval: it defines the c of the constraint F(x)=c.
3939
:cvar list indices: it defines the indices of the control points
4040
that are moved to enforce the constraint. The control index is obtained by doing:
4141
all_indices=np.arange(n_x*n_y*n_z*3).reshape(n_x,n_y,n_z,3).tolist().
@@ -53,31 +53,31 @@ class CFFD(FFD):
5353
>>> x=x.reshape(-1)
5454
>>> return A@x
5555
>>> b=fun(original_mesh_points)
56-
>>> cffd.linconstraint=fun
57-
>>> cffd.valconstraint=b
56+
>>> cffd.fun=fun
57+
>>> cffd.fixval=b
5858
>>> cffd.indices=np.arange(np.prod(cffd.n_control_points)*3).tolist()
5959
>>> cffd.M=np.eye(len(cffd.indices))
6060
>>> new_mesh_points = cffd(original_mesh_points)
6161
>>> assert np.isclose(np.linalg.norm(fun(new_mesh_points)-b),np.array([0.]))
6262
"""
63-
6463
def __init__(self, n_control_points=None):
6564
super().__init__(n_control_points)
66-
self.linconstraint = None
67-
self.valconstraint = None
65+
self.fun = None
66+
self.fixval = None
6867
self.indices = None
6968
self.M = None
7069

7170
def __call__(self, src_pts):
72-
saved_parameters=self._save_parameters()
71+
saved_parameters = self._save_parameters()
7372
A, b = self._compute_linear_map(src_pts, saved_parameters.copy())
7473
d = A @ saved_parameters[self.indices] + b
75-
deltax = np.linalg.inv(self.M) @ A.T @ np.linalg.inv((A @ np.linalg.inv(self.M)@ A.T)) @ (self.valconstraint - d)
74+
deltax = np.linalg.inv(self.M) @ A.T @ np.linalg.inv(
75+
(A @ np.linalg.inv(self.M) @ A.T)) @ (self.fixval - d)
7676
saved_parameters[self.indices] = saved_parameters[self.indices] + deltax
7777
self._load_parameters(saved_parameters)
7878
return self.ffd(src_pts)
7979

80-
def ffd(self,src_pts):
80+
def ffd(self, src_pts):
8181
'''
8282
Performs Classic Free Form Deformation.
8383
@@ -87,35 +87,35 @@ def ffd(self,src_pts):
8787
'''
8888
return super().__call__(src_pts)
8989

90-
9190
def _save_parameters(self):
9291
'''
9392
Saves the FFD control points in an array of shape [n_x,ny,nz,3].
9493
9594
:return: the FFD control points in an array of shape [n_x,ny,nz,3].
9695
:rtype: numpy.ndarray
9796
'''
98-
tmp = np.zeros([*self.n_control_points,3])
97+
tmp = np.zeros([*self.n_control_points, 3])
9998
tmp[:, :, :, 0] = self.array_mu_x
10099
tmp[:, :, :, 1] = self.array_mu_y
101100
tmp[:, :, :, 2] = self.array_mu_z
102101
return tmp.reshape(-1)
103-
104-
def _load_parameters(self,tmp):
102+
103+
def _load_parameters(self, tmp):
105104
'''
106105
Loads the FFD control points from an array of shape [n_x,ny,nz,3].
107106
108107
:param np.ndarray tmp: the array of FFD control points.
109108
:rtype: None
110109
'''
111-
tmp = tmp.reshape(*self.n_control_points,3)
110+
tmp = tmp.reshape(*self.n_control_points, 3)
112111
self.array_mu_x = tmp[:, :, :, 0]
113112
self.array_mu_y = tmp[:, :, :, 1]
114113
self.array_mu_z = tmp[:, :, :, 2]
115114

116115

117116
# I see that a similar function already exists in pygem.utils, but it does not work for inputs and outputs of different dimensions
118-
def _compute_linear_map(self,src_pts,saved_parameters):
117+
118+
def _compute_linear_map(self, src_pts, saved_parameters):
119119
'''
120120
Computes the coefficient and the intercept of the linear map from the control points to the output.
121121
@@ -124,21 +124,25 @@ def _compute_linear_map(self,src_pts,saved_parameters):
124124
:return: a tuple containing the coefficient and the intercept.
125125
:rtype: tuple(np.ndarray,np.ndarray)
126126
'''
127-
saved_parameters_bak=saved_parameters.copy() #saving ffd parameters
128-
n_indices=len(self.indices)
129-
inputs=np.zeros([n_indices+1,n_indices+1])
130-
outputs=np.zeros([n_indices+1,self.valconstraint.shape[0]])
127+
n_indices = len(self.indices)
128+
inputs = np.zeros([n_indices + 1, n_indices + 1])
129+
outputs = np.zeros([n_indices + 1, self.fixval.shape[0]])
131130
np.random.seed(0)
132-
for i in range(n_indices+1): ##now we generate the interpolation points
133-
tmp=np.random.rand(1,n_indices)
134-
tmp=tmp.reshape(1,-1)
135-
inputs[i]=np.hstack([tmp, np.ones((tmp.shape[0], 1))]) #dependent variable
136-
saved_parameters[self.indices]=tmp
137-
self._load_parameters(saved_parameters) #loading the depent variable as a control point
138-
def_pts=super().__call__(src_pts) #computing the deformation with the dependent variable
139-
outputs[i]=self.linconstraint(def_pts) #computing the independent variable
140-
sol=np.linalg.lstsq(inputs,outputs,rcond=None) #computation of the linear map
141-
A=sol[0].T[:,:-1] #coefficient
142-
b=sol[0].T[:,-1] #intercept
143-
self._load_parameters(saved_parameters_bak) #restoring the original FFD parameters
144-
return A,b
131+
for i in range(n_indices +
132+
1): ##now we generate the interpolation points
133+
tmp = np.random.rand(1, n_indices)
134+
tmp = tmp.reshape(1, -1)
135+
inputs[i] = np.hstack([tmp, np.ones(
136+
(tmp.shape[0], 1))]) #dependent variable
137+
saved_parameters[self.indices] = tmp
138+
self._load_parameters(
139+
saved_parameters
140+
) #loading the depent variable as a control point
141+
def_pts = super().__call__(
142+
src_pts) #computing the deformation with the dependent variable
143+
outputs[i] = self.fun(def_pts) #computing the independent variable
144+
sol = np.linalg.lstsq(inputs, outputs,
145+
rcond=None) #computation of the linear map
146+
A = sol[0].T[:, :-1] #coefficient
147+
b = sol[0].T[:, -1] #intercept
148+
return A, b

pygem/vffd.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
from copy import deepcopy
1414

15+
1516
class VFFD(CFFD):
1617
'''
1718
Class that handles the Volumetric Free Form Deformation on the mesh points.
@@ -58,32 +59,33 @@ class VFFD(CFFD):
5859
>>> new_mesh_points = vffd(original_mesh_points)
5960
>>> assert np.isclose(np.linalg.norm(vffd.linconstraint(new_mesh_points)-b),np.array([0.]))
6061
'''
61-
62-
def __init__(self,triangles ,n_control_points=None):
62+
def __init__(self, triangles, n_control_points=None):
6363
super().__init__(n_control_points)
64-
self.triangles=triangles
65-
self.vweight=[1/3,1/3,1/3]
64+
self.triangles = triangles
65+
self.vweight = [1 / 3, 1 / 3, 1 / 3]
66+
6667
def volume(x):
67-
x=x.reshape(-1,3)
68-
mesh=x[self.triangles]
69-
return np.sum(np.linalg.det(mesh))
70-
self.linconstraint=volume
68+
x = x.reshape(-1, 3)
69+
mesh = x[self.triangles]
70+
return np.sum(np.linalg.det(mesh))
7171

72+
self.fun = volume
7273

73-
def __call__(self,src_pts):
74-
self.vweight=np.abs(self.vweight)/np.sum(np.abs(self.vweight))
75-
indices_bak=deepcopy(self.indices)
76-
self.indices=np.array(self.indices)
77-
indices_x=self.indices[self.indices%3==0].tolist()
78-
indices_y=self.indices[self.indices%3==1].tolist()
79-
indices_z=self.indices[self.indices%3==2].tolist()
80-
indexes=[indices_x,indices_y,indices_z]
81-
diffvolume=self.valconstraint-self.linconstraint(self.ffd(src_pts))
74+
def __call__(self, src_pts):
75+
self.vweight = np.abs(self.vweight) / np.sum(np.abs(self.vweight))
76+
indices_bak = deepcopy(self.indices)
77+
self.indices = np.array(self.indices)
78+
indices_x = self.indices[self.indices % 3 == 0].tolist()
79+
indices_y = self.indices[self.indices % 3 == 1].tolist()
80+
indices_z = self.indices[self.indices % 3 == 2].tolist()
81+
indexes = [indices_x, indices_y, indices_z]
82+
diffvolume = self.fixval - self.fun(self.ffd(src_pts))
8283
for i in range(3):
83-
self.indices=indexes[i]
84-
self.M=np.eye(len(self.indices))
85-
self.valconstraint=self.linconstraint(self.ffd(src_pts))+self.vweight[i]*(diffvolume)
86-
_=super().__call__(src_pts)
87-
tmp=super().__call__(src_pts)
88-
self.indices=indices_bak
89-
return tmp
84+
self.indices = indexes[i]
85+
self.M = np.eye(len(self.indices))
86+
self.fixval = self.fun(
87+
self.ffd(src_pts)) + self.vweight[i] * (diffvolume)
88+
_ = super().__call__(src_pts)
89+
tmp = super().__call__(src_pts)
90+
self.indices = indices_bak
91+
return tmp

tests/test_bffd.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ class TestBFFD(TestCase):
99
def test_nothing_happens(self):
1010
np.random.seed(0)
1111
cffd = BFFD()
12-
original_mesh_points = np.load("tests/test_datasets/meshpoints_sphere_orig.npy")
12+
original_mesh_points = np.load(
13+
"tests/test_datasets/meshpoints_sphere_orig.npy")
1314
A = np.random.rand(3, original_mesh_points.reshape(-1).shape[0])
14-
b = cffd.linconstraint(original_mesh_points)
15-
cffd.valconstraint = b
15+
b = cffd.fun(original_mesh_points)
16+
cffd.fixval = b
1617
cffd.indices = np.arange(np.prod(cffd.n_control_points) * 3).tolist()
1718
cffd.M = np.eye(len(cffd.indices))
1819
new_mesh_points = cffd(original_mesh_points)
@@ -21,13 +22,14 @@ def test_nothing_happens(self):
2122
def test_constraint(self):
2223
np.random.seed(0)
2324
cffd = BFFD()
24-
cffd.read_parameters("tests/test_datasets/parameters_test_ffd_sphere.prm")
25-
original_mesh_points = np.load("tests/test_datasets/meshpoints_sphere_orig.npy")
26-
b = cffd.linconstraint(original_mesh_points)
27-
cffd.valconstraint = b
25+
cffd.read_parameters(
26+
"tests/test_datasets/parameters_test_ffd_sphere.prm")
27+
original_mesh_points = np.load(
28+
"tests/test_datasets/meshpoints_sphere_orig.npy")
29+
b = cffd.fun(original_mesh_points)
30+
cffd.fixval = b
2831
cffd.indices = np.arange(np.prod(cffd.n_control_points) * 3).tolist()
2932
cffd.M = np.eye(len(cffd.indices))
3033
new_mesh_points = cffd(original_mesh_points)
31-
assert np.isclose(
32-
np.linalg.norm(cffd.linconstraint(new_mesh_points) - b), np.array([0.0])
33-
)
34+
assert np.isclose(np.linalg.norm(cffd.fun(new_mesh_points) - b),
35+
np.array([0.0]))

tests/test_cffd.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@ class TestCFFD(TestCase):
99
def test_nothing_happens(self):
1010
np.random.seed(0)
1111
cffd = CFFD()
12-
original_mesh_points = np.load("tests/test_datasets/meshpoints_sphere_orig.npy")
12+
original_mesh_points = np.load(
13+
"tests/test_datasets/meshpoints_sphere_orig.npy")
1314
A = np.random.rand(3, original_mesh_points.reshape(-1).shape[0])
1415

1516
def fun(x):
1617
x = x.reshape(-1)
1718
return A @ x
1819

1920
b = fun(original_mesh_points)
20-
cffd.linconstraint = fun
21-
cffd.valconstraint = b
21+
cffd.fun = fun
22+
cffd.fixval = b
2223
cffd.indices = np.arange(np.prod(cffd.n_control_points) * 3).tolist()
2324
cffd.M = np.eye(len(cffd.indices))
2425
new_mesh_points = cffd(original_mesh_points)
@@ -27,34 +28,38 @@ def fun(x):
2728
def test_constraint(self):
2829
np.random.seed(0)
2930
cffd = CFFD()
30-
cffd.read_parameters("tests/test_datasets/parameters_test_ffd_sphere.prm")
31-
original_mesh_points = np.load("tests/test_datasets/meshpoints_sphere_orig.npy")
31+
cffd.read_parameters(
32+
"tests/test_datasets/parameters_test_ffd_sphere.prm")
33+
original_mesh_points = np.load(
34+
"tests/test_datasets/meshpoints_sphere_orig.npy")
3235
A = np.random.rand(3, original_mesh_points.reshape(-1).shape[0])
3336

3437
def fun(x):
3538
x = x.reshape(-1)
3639
return A @ x
3740

3841
b = fun(original_mesh_points)
39-
cffd.linconstraint = fun
40-
cffd.valconstraint = b
42+
cffd.fun = fun
43+
cffd.fixval = b
4144
cffd.indices = np.arange(np.prod(cffd.n_control_points) * 3).tolist()
4245
cffd.M = np.eye(len(cffd.indices))
4346
new_mesh_points = cffd(original_mesh_points)
44-
assert np.isclose(np.linalg.norm(fun(new_mesh_points) - b), np.array([0.0]))
47+
assert np.isclose(np.linalg.norm(fun(new_mesh_points) - b),
48+
np.array([0.0]))
4549

4650
def test_interpolation(self):
4751
cffd = CFFD()
48-
original_mesh_points = np.load("tests/test_datasets/meshpoints_sphere_orig.npy")
52+
original_mesh_points = np.load(
53+
"tests/test_datasets/meshpoints_sphere_orig.npy")
4954
A = np.random.rand(3, original_mesh_points.reshape(-1).shape[0])
5055

5156
def fun(x):
5257
x = x.reshape(-1)
5358
return A @ x
5459

5560
b = fun(original_mesh_points)
56-
cffd.linconstraint = fun
57-
cffd.valconstraint = b
61+
cffd.fixval = b
62+
cffd.fun = fun
5863
cffd.indices = np.arange(np.prod(cffd.n_control_points) * 3).tolist()
5964
cffd.M = np.eye(len(cffd.indices))
6065
save_par = cffd._save_parameters()
@@ -64,8 +69,7 @@ def fun(x):
6469
save_par[cffd.indices] = tmp
6570
cffd._load_parameters(tmp)
6671
assert np.allclose(
67-
np.linalg.norm(
68-
C @ tmp + d - cffd.linconstraint(cffd.ffd(original_mesh_points))
69-
),
72+
np.linalg.norm(C @ tmp + d -
73+
cffd.fun(cffd.ffd(original_mesh_points))),
7074
0.0,
7175
)

0 commit comments

Comments
 (0)