Skip to content

Commit 34dfeb9

Browse files
committed
Fix kabsch alignment
1 parent 2f22dad commit 34dfeb9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

utils/rmsd.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ def kabsch_align(P, Q):
3030
H = jnp.dot(p.T, q)
3131

3232
# SVD
33-
U, S, Vt = jnp.linalg.svd(H)
33+
U, _, Vt = jnp.linalg.svd(H)
3434

3535
# Validate right-handed coordinate system
36-
Vt = jnp.where(jnp.linalg.det(jnp.dot(Vt.T, U.T)) < 0.0, -Vt, Vt)
36+
det = jnp.linalg.det(jnp.dot(Vt.T, U.T))
37+
Vt = jnp.where(det < 0.0, Vt.at[-1, :].set(Vt[-1, :] * -1.0), Vt)
3738

3839
# Optimal rotation
3940
R = jnp.dot(Vt.T, U.T)
@@ -42,6 +43,7 @@ def kabsch_align(P, Q):
4243

4344

4445
@jax.jit
46+
@jax.vmap
4547
def kabsch_rmsd(P, Q):
4648
P_aligned, Q_aligned = kabsch_align(P, Q)
4749
return jnp.sqrt(jnp.sum(jnp.square(P_aligned - Q_aligned)) / P.shape[0])

0 commit comments

Comments
 (0)