Skip to content

Commit 7278e2e

Browse files
moved to multi_dot in main function
1 parent e501eae commit 7278e2e

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

pygem/cffd.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ def __call__(self, src_pts):
8888
A, b = self._compute_linear_map(src_pts, saved_parameters.copy(),indices)
8989
d = A @ saved_parameters[indices] + b
9090
invM=np.linalg.inv(self.M)
91-
deltax = invM @ A.T @ np.linalg.inv(
92-
(A @ invM @ A.T)) @ (self.fixval - d)
91+
deltax = np.linalg.multi_dot([invM , A.T , np.linalg.inv((A @ invM @ A.T)) , (self.fixval - d)])
9392
saved_parameters[indices] = saved_parameters[indices] + deltax
9493
self._load_parameters(saved_parameters)
9594
return self.ffd(src_pts)
@@ -167,22 +166,5 @@ def _compute_linear_map(self, src_pts, saved_parameters,indices):
167166
A = sol[0].T[:, :-1] #coefficient
168167
b = sol[0].T[:, -1] #intercept
169168
return A, b
169+
170170

171-
np.random.seed(0)
172-
cffd = CFFD()
173-
cffd.read_parameters(
174-
"tests/test_datasets/parameters_test_ffd_sphere.prm")
175-
original_mesh_points = np.load(
176-
"tests/test_datasets/meshpoints_sphere_orig.npy")
177-
A = np.random.rand(3, original_mesh_points.reshape(-1).shape[0])
178-
179-
def fun(x):
180-
x = x.reshape(-1)
181-
return A @ x
182-
183-
b = fun(original_mesh_points)
184-
cffd.fun = fun
185-
cffd.fixval = b
186-
new_mesh_points = cffd(original_mesh_points)
187-
assert np.isclose(np.linalg.norm(fun(new_mesh_points) - b),
188-
np.array([0.0]))

tests/test_cffd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def fun(x):
4040
cffd.fun = fun
4141
cffd.fixval = b
4242
new_mesh_points = cffd(original_mesh_points)
43+
print(np.linalg.norm(fun(new_mesh_points) - b))
4344
assert np.isclose(np.linalg.norm(fun(new_mesh_points) - b),
44-
np.array([0.0]))
45+
np.array([0.0]),atol=1e-7)
4546

4647
def test_interpolation(self):
4748
cffd = CFFD()

0 commit comments

Comments
 (0)