A collection of the Kabsch (SVD-based) and Horn (Quaternion-based) optimal structural alignment algorithms, implemented natively across five Python math frameworks:
- 🐍 NumPy
- 🔥 PyTorch
- 🌌 JAX
- 🧱 TensorFlow
- 🍎 MLX
```
pip install kabsch-horn-cookbook
```

Or with uv:

```
uv add kabsch-horn-cookbook
```

Alternatively, copy the framework folder you need from `src/kabsch_horn/<framework>/` directly into your project. The code has no runtime dependencies beyond the framework itself, so nothing new gets added to your environment.
```
src/kabsch_horn/
├── numpy/
├── pytorch/
├── jax/
├── tensorflow/
└── mlx/
```
Each folder contains two files: `kabsch_svd_nd.py` for SVD-based alignment and `horn_quat_3d.py` for quaternion-based alignment. Copy one file, both, or the whole folder -- the MIT license lets you borrow, modify, and redistribute freely.
Each framework exports a consistent API. Batch processing (`[..., N, D]` with arbitrary leading dims) is supported for all functions.
```python
import torch
from kabsch_horn import pytorch as kh
# Replace `pytorch` with `jax`, `tensorflow`, `mlx`, or `numpy` as needed.

# 1. N-dimensional SVD Kabsch
# N-dimensional points (e.g., representation matching in 64D)
P_nd = torch.randn(10, 100, 64)
Q_nd = torch.randn(10, 100, 64)

# R: (Batch, 64, 64) | t: (Batch, 64) | rmsd: (Batch,)
R, t, rmsd = kh.kabsch(P_nd, Q_nd)

# Umeyama variant (with global scale)
# R: (Batch, 64, 64) | t: (Batch, 64) | c: (Batch,) | rmsd: (Batch,)
R, t, c, rmsd = kh.kabsch_umeyama(P_nd, Q_nd)

# 2. 3D closed-form quaternion Horn
# 3D points (e.g., standard molecular/physics alignment)
P_3d = torch.randn(10, 100, 3)
Q_3d = torch.randn(10, 100, 3)

# R: (Batch, 3, 3) | t: (Batch, 3) | rmsd: (Batch,)
R, t, rmsd = kh.horn(P_3d, Q_3d)

# Horn with scale
# R: (Batch, 3, 3) | t: (Batch, 3) | c: (Batch,) | rmsd: (Batch,)
R, t, c, rmsd = kh.horn_with_scale(P_3d, Q_3d)

# Single-call RMSD loss (autodiff frameworks only; gradients remain stable)
loss = kh.kabsch_rmsd(P_nd, Q_nd)
loss.mean().backward()

# Per-point weights (e.g., confidence scores, B-factors)
weights = torch.rand(10, 100)  # shape: [Batch, Points]
R, t, rmsd = kh.kabsch(P_nd, Q_nd, weights=weights)
```

| Function | NumPy | PyTorch | JAX | TensorFlow | MLX |
|---|---|---|---|---|---|
| `kabsch` | ✓ | ✓ | ✓ | ✓ | ✓ (3D only) |
| `kabsch_umeyama` | ✓ | ✓ | ✓ | ✓ | ✓ (3D only) |
| `horn` | ✓ | ✓ | ✓ | ✓ | ✓ |
| `horn_with_scale` | ✓ | ✓ | ✓ | ✓ | ✓ |
| `kabsch_rmsd` | -- | ✓ | ✓ | ✓ | ✓ |
| `kabsch_umeyama_rmsd` | -- | ✓ | ✓ | ✓ | ✓ |
| Gradient-safe backward | -- | ✓ | ✓ | ✓ | ✓ |
NumPy provides forward-pass evaluation only. MLX uses a hardcoded 3x3 determinant correction and raises `ValueError` for non-3D inputs. JAX float64 requires `JAX_ENABLE_X64=True` to be set before importing JAX; otherwise inputs are silently downcast to float32.
All functions are compatible with torch.compile and jax.jit. Wrapping is optional -- functions work correctly without it -- but can improve throughput for repeated calls.
PyTorch (`torch.compile`):

```python
import torch
from kabsch_horn import pytorch as kh

compiled_kabsch = torch.compile(kh.kabsch)
R, t, rmsd = compiled_kabsch(P, Q)
```

The custom autograd functions (SafeSVD, SafeEigh) act as graph breaks under `torch.compile`, so the compiler cannot fuse operations across the SVD/eigh boundary. Surrounding code is still compiled and optimized.
JAX (`jax.jit`):

```python
import jax
from kabsch_horn import jax as kh

jitted_kabsch = jax.jit(kh.kabsch)
R, t, rmsd = jitted_kabsch(P, Q)
```

JAX float64 requires `JAX_ENABLE_X64=True` to be set before importing JAX; otherwise inputs are silently downcast to float32. This applies whether or not you use `jit`.
The Kabsch algorithm is traditionally applied to 3D coordinates, but this SVD implementation supports N-dimensional alignment and scales to higher dimensions for tasks like mapping internal representations between models.
Horn's method applies strictly to 3D space. It uses a closed-form quaternion eigendecomposition to compute the alignment. The quaternion formulation inherently produces proper rotations (det(R) = +1), so no reflection correction is needed.
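The closed form is compact enough to sketch in plain NumPy (an illustrative sketch, not the library source; `horn_np` is a hypothetical name): build Horn's symmetric 4x4 matrix from the 3x3 cross-covariance, take the eigenvector of the largest eigenvalue as the unit quaternion, and expand it into a rotation matrix.

```python
import numpy as np

def horn_np(P, Q):
    """Rotation R with Q_c ≈ R @ P_c via Horn's quaternion method, [N, 3] inputs."""
    Pc, Qc = P - P.mean(axis=0), Q - Q.mean(axis=0)
    S = Pc.T @ Qc                                  # 3x3 cross-covariance
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = S
    # Horn's symmetric 4x4 matrix; its top eigenvector is the optimal quaternion.
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy,        Szx - Sxz,        Sxy - Syx],
        [Syz - Szy,       Sxx - Syy - Szz,  Sxy + Syx,        Szx + Sxz],
        [Szx - Sxz,       Sxy + Syx,       -Sxx + Syy - Szz,  Syz + Szy],
        [Sxy - Syx,       Szx + Sxz,        Syz + Szy,       -Sxx - Syy + Szz],
    ])
    w, x, y, z = np.linalg.eigh(N)[1][:, -1]       # unit quaternion (w, x, y, z)
    # Standard quaternion-to-rotation-matrix expansion (q and -q give the same R).
    return np.array([
        [1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
        [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
        [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)],
    ])
```

Because the quaternion parameterizes SO(3) directly, the result is always a proper rotation, with no determinant fix-up step.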
Point cloud alignments evaluated during neural network training often hit mathematically degenerate states: point clouds with perfect symmetry, for example, produce repeated eigenvalues or singular values. Standard library backward passes divide by the differences of these values, producing division by zero and NaN gradients that corrupt network weights.
This cookbook addresses this directly. The autograd wrappers for PyTorch, JAX, TensorFlow, and MLX override the standard SVD and eigh backward passes, masking near-identical roots with epsilon factors during backpropagation.
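The masking idea can be sketched in NumPy (illustrative only; `masked_f_matrix` is a hypothetical name, and the real wrappers live in each framework's SafeSVD/SafeEigh). The SVD backward pass contains terms of the form 1/(s_j² - s_i²), which blow up when two singular values coincide; clamping tiny denominators to machine epsilon keeps gradients finite.

```python
import numpy as np

def masked_f_matrix(s, eps=None):
    """Sketch of the SafeSVD masking idea: the SVD backward pass uses
    F_ij = 1 / (s_j**2 - s_i**2), which is inf for repeated singular
    values. Clamping tiny denominators to eps keeps gradients finite."""
    if eps is None:
        eps = np.finfo(s.dtype).eps
    s2 = s ** 2
    diff = s2[None, :] - s2[:, None]          # diff[i, j] = s_j^2 - s_i^2
    # Replace denominators smaller than eps by +/- eps (sign-preserving).
    safe = np.where(np.abs(diff) < eps, np.copysign(eps, diff), diff)
    F = 1.0 / safe
    np.fill_diagonal(F, 0.0)                  # diagonal terms are unused
    return F

# Degenerate spectrum: two identical singular values, yet F stays finite.
F = masked_f_matrix(np.array([2.0, 1.0, 1.0]))
print(np.isfinite(F).all())                   # True
```

The actual wrappers apply the same clamp inside each framework's custom backward, so forward results are untouched and only the gradient is regularized.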
Gradient stability is verified by the test suite. Hypothesis property tests confirm that gradients remain finite across coplanar, collinear, reflected, and collapsed inputs. A dedicated descent-direction test confirms that SafeSVD's masked gradients at near-degenerate inputs still reduce RMSD in a gradient step -- gradient accuracy against finite differences is also verified at float32 and float64. See tests/test_differentiability_traps.py and tests/test_gradient_verification.py.
The following mathematical properties are validated by property-based tests using Hypothesis. Each claim links to the test that justifies it.
These hold for all frameworks, all precisions, and all valid input shapes.
| Property | Algorithms | Test |
|---|---|---|
| R is orthogonal (R^T R = I) | `kabsch`, `horn` | `test_rotation_is_orthogonal_*` |
| det(R) = +1 (proper rotation) | `kabsch`, `horn` | `test_rotation_det_is_positive_*` |
| RMSD >= 0 | all | `test_rmsd_is_nonnegative` |
| Scale c > 0 | `kabsch_umeyama`, `horn_with_scale` | `test_scale_is_positive_*` |
These are verified with NumPy over Hypothesis-drawn inputs.
| Property | Test |
|---|---|
| RMSD equals the residual after applying (R, t) | `test_rmsd_equals_transform_residual` |
| No rotation achieves a lower RMSD (optimality) | `test_no_rotation_achieves_lower_rmsd` |
| `kabsch_rmsd(P, Q)` equals `kabsch_rmsd(Q, P)` | `test_kabsch_rmsd_is_symmetric` |
| RMSD is invariant to rigid transforms of the inputs | `test_rmsd_invariant_to_rigid_transform` |
| R is invariant to translation | `test_r_invariant_to_translation` |
| R is invariant to uniform scaling | `test_r_invariant_to_uniform_scale` |
| When the relative scale is 1, `kabsch_umeyama` matches `kabsch` | `test_umeyama_equals_kabsch_when_no_scale_change` |
| When Q is an exactly scaled rigid copy of P, `kabsch_umeyama` recovers the scale | `test_umeyama_recovers_exact_scale` |
In 3D, the SVD and quaternion formulations solve the same optimization problem, so their outputs must agree whenever the cross-covariance matrix is non-degenerate:
| Property | Test |
|---|---|
| `kabsch` and `horn` return identical rotations in 3D | `test_kabsch_and_horn_agree_on_rotation_3d` |
| `kabsch_umeyama` and `horn_with_scale` agree in 3D | `test_umeyama_and_horn_with_scale_agree_3d` |
SafeSVD and SafeEigh override the standard backward pass to mask near-zero singular-value and eigenvalue differences with `finfo(dtype).eps`. The table below lists the degenerate cases explicitly tested.

| Degenerate input | Guarantee | Test |
|---|---|---|
| Identical point clouds | Finite gradient | `test_gradients_are_stable_when_points_are_identical` |
| Coplanar points | Finite gradient | `test_gradients_are_stable_when_points_are_coplanar` |
| Collinear points | Finite gradient + descent direction | `test_gradients_are_stable_when_points_are_collinear` |
| Near-collinear (Hypothesis-drawn) | Finite gradient | `test_gradients_stable_nearly_collinear_hypothesis` |
| Near-collinear, different clouds | Finite gradient | `test_gradients_stable_nearly_collinear_different_clouds` |
| Near-coplanar (Hypothesis-drawn) | Finite gradient | `test_gradients_stable_nearly_coplanar_hypothesis` |
| Reflection (improper alignment) | Finite gradient + det(R) = +1 | `test_gradients_are_stable_when_points_are_reflected` |
| Underdetermined (fewer points than dimensions) | Finite gradient | `test_gradients_are_stable_when_system_is_underdetermined` |
| Collapse to origin | Finite gradient | `test_gradients_are_stable_when_points_collapse_to_origin` |
| Near-collinear or coplanar (Hypothesis) | Descent direction | `test_safe_svd_gradient_reduces_rmsd_at_hypothesis_near_degenerate` |
"Descent direction" means one gradient step with test_gradients_match_finite_differences_when_perturbed and test_gradients_match_finite_differences_when_purely_random. Hypothesis-varied FD checks run at float64 only, where the tolerance is tight enough to be meaningful.
Some inputs are fundamentally degenerate. The library does not raise errors in these cases, but users should understand the implications.
Near-collinear clouds -- rotation is ambiguous. When the cross-covariance matrix is rank-deficient, multiple rotations achieve the same minimal RMSD, and the returned R is only one valid choice; see `test_rotation_is_not_unique_when_cross_covariance_is_degenerate`.
MLX: 3D inputs only. MLX uses a hardcoded 3x3 determinant correction and raises ValueError for dim != 3.
NumPy: forward pass only. NumPy provides no autograd wrappers and does not export kabsch_rmsd or kabsch_umeyama_rmsd.
float16 / bfloat16: variance division can overflow. kabsch_umeyama and horn_with_scale divide by the point cloud variance. This overflows in half precision when inputs are near-collinear or collapsed to the origin. For production half-precision training, cast inputs to float32 before calling alignment functions.
float16 / bfloat16: accuracy is limited. Half-precision forward passes are tested with atol=0.1 / rtol=0.1. Deterministic finite-difference gradient checks are skipped for float16/bfloat16 because the effective tolerance (atol * 50 = 5.0) is too loose to be meaningful. For training stability, prefer float32 or higher.
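The variance overflow is easy to reproduce in plain NumPy (illustrative only; magnitudes are chosen so squared deviations exceed float16's maximum finite value of ~65504):

```python
import numpy as np

# Collinear cloud spanning +/-500 per axis: squared deviations reach 2.5e5,
# far beyond float16's max finite value (~65504).
line = np.linspace(-500.0, 500.0, 100, dtype=np.float16)
pts16 = np.tile(line[:, None], (1, 3))            # [100, 3] float16 points

dev16 = pts16 - pts16.mean()
var16 = (dev16 * dev16).mean()                    # squares overflow to inf

pts32 = pts16.astype(np.float32)                  # cast first, then compute
var32 = np.square(pts32 - pts32.mean()).mean()    # finite
```

The same cast-to-float32-first pattern applies before calling the scale-estimating functions in half-precision pipelines.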
MLX float64 runs on CPU. Apple Silicon GPUs do not support true float64, so float64 ops are automatically routed to CPU. float32 and half-precision inputs use GPU acceleration as normal.
JAX: double backward through kabsch/kabsch_umeyama is unsupported. JAX's custom_vjp does not implement an SVD JVP, so jax.grad(jax.grad(f)) through the Kabsch code path raises NotImplementedError upstream. Horn and horn_with_scale (eigh-based) support double backward in JAX without issue. TensorFlow, MLX, and PyTorch support double backward for all algorithms.
MLX: NaN inputs abort the process. mlx.linalg.svd fatally terminates the process when given NaN inputs. Every other framework propagates NaN gracefully. Validate inputs before passing them to MLX alignment functions if NaN is possible in your pipeline.
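Since MLX aborts rather than raising, a cheap host-side guard is worthwhile when NaN is possible. A minimal sketch (the `assert_finite` helper is hypothetical, shown here with NumPy-side validation):

```python
import numpy as np

def assert_finite(*arrays):
    """Raise a catchable error instead of letting an SVD call fatally
    abort the process on NaN/inf inputs."""
    for i, a in enumerate(arrays):
        if not np.isfinite(np.asarray(a)).all():
            raise ValueError(f"argument {i} contains NaN or inf")

P = np.random.randn(10, 100, 3)
Q = np.random.randn(10, 100, 3)
assert_finite(P, Q)  # passes: safe to hand P and Q to the MLX functions
```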
Each framework's kabsch_rmsd and kabsch_umeyama_rmsd functions are the simplest entry point for gradient-based training. For more complex losses, call kabsch or horn directly and operate on the returned R, t, and rmsd tensors:
```python
from kabsch_horn import pytorch as kh

def contrastive_alignment_loss(P_pos, Q_pos, P_neg, Q_neg, margin=1.0):
    rmsd_pos = kh.kabsch_rmsd(P_pos, Q_pos)
    rmsd_neg = kh.kabsch_rmsd(P_neg, Q_neg)
    return (rmsd_pos - rmsd_neg + margin).clamp(min=0).mean()
```

The rotation matrix R returned by `kabsch` and `horn` is differentiable, so it can be composed into downstream losses (e.g., point-to-point error after applying a learned perturbation on top of R).
To port these algorithms to a new backend, implement the following interface:
- `safe_svd(A)` -- a custom-gradient SVD that masks near-zero singular value differences in the backward pass with `finfo(dtype).eps`. See `src/kabsch_horn/pytorch/kabsch_svd_nd.py` (`SafeSVD`) for the reference implementation.
- `safe_eigh(A)` -- the same pattern for eigendecomposition, used by Horn's method. See `SafeEigh` in `src/kabsch_horn/pytorch/horn_quat_3d.py`.
- `kabsch(P, Q)` -- accepts `[N, D]` or `[..., N, D]` inputs and returns `(R, t, rmsd)`.
- `horn(P, Q)` -- accepts `[N, 3]` or `[..., N, 3]` inputs and returns `(R, t, rmsd)`.
The NumPy module (src/kabsch_horn/numpy/) is a clean forward-pass-only reference with no autograd dependencies, useful as a starting point.
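For orientation, the forward pass of the Kabsch algorithm itself fits in a few lines of NumPy (an illustrative sketch of the standard algorithm with the reflection fix, not the library source; `kabsch_np` is a hypothetical name):

```python
import numpy as np

def kabsch_np(P, Q):
    """Return (R, t, rmsd) minimizing ||R @ P_i + t - Q_i|| for [N, D] inputs."""
    p_bar, q_bar = P.mean(axis=0), Q.mean(axis=0)
    Pc, Qc = P - p_bar, Q - q_bar
    H = Pc.T @ Qc                               # cross-covariance, [D, D]
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))      # reflection correction
    D_fix = np.eye(H.shape[0])
    D_fix[-1, -1] = d                           # flip last axis if improper
    R = Vt.T @ D_fix @ U.T                      # proper rotation, det = +1
    t = q_bar - R @ p_bar
    rmsd = np.sqrt(np.mean(np.sum((P @ R.T + t - Q) ** 2, axis=1)))
    return R, t, rmsd
```

A custom-gradient `safe_svd` then slots in where `np.linalg.svd` is called, which is the only backend-specific piece.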
The test suite is organized around mathematical claims rather than code coverage. Each test file targets a distinct category of properties.
| File | What it proves |
|---|---|
| `tests/test_forward_pass_equivalence.py` | Identical outputs across all frameworks and precisions for the same input; correct batching across `[..., N, D]` shapes |
| `tests/test_properties.py` | Output invariants (orthogonality, det = +1, RMSD >= 0), correctness invariants (RMSD definition, optimality, symmetry, rigid-transform invariance), and cross-algorithm consistency (kabsch = horn in 3D) |
| `tests/test_differentiability_traps.py` | Gradient finiteness across all documented degenerate cases; descent direction at singularities |
| `tests/test_gradient_verification.py` | Analytic gradients match finite differences (deterministic inputs at float32 + float64; Hypothesis-varied inputs at float64 only); batched gradients match sequential; SafeSVD descent at near-degenerate inputs; double backward (PyTorch, TensorFlow, MLX: all algorithms; JAX: Horn only) |
| `tests/test_degeneracy.py` | Forward-pass validity under extreme degeneracy (origin collapse, collinear, coplanar, underdetermined) |
| `tests/test_catastrophic_cancellation.py` | Numerical stability at extreme coordinate magnitudes (1e-6 to 1e6) |
| `tests/test_error_handling.py` | Correct exceptions for mismatched shapes, wrong dimensions, and invalid inputs |
| `tests/test_rmsd_wrappers.py` | `kabsch_rmsd` and `kabsch_umeyama_rmsd` match full-call RMSD output; N=1 single-point edge cases |
| `tests/test_reference_validation.py` | Cross-framework validation against NumPy reference outputs over multiple seeds |
| `tests/test_mixed_dtype.py` | Correct behavior when P and Q have different dtypes |
| `tests/test_mlx_float64_warning.py` | MLX emits a warning when float64 silently falls back to CPU |
| `tests/test_tf_dynamic_validation.py` | TensorFlow runtime shape validation for dynamic shapes |
| `tests/test_weighted.py` | Per-point weighted alignment: uniform equivalence, outlier downweighting, gradient stability, error handling, and batching |
The suite runs across 4 frameworks x 4 precisions (float16, bfloat16, float32, float64), with MLX restricted to 3D. Hypothesis property tests use configurable example counts; CI runs the defaults.
Run the test suite with:

```
uv run pytest tests/
```

- [Kabsch 1976] Kabsch, W. (1976). "A solution for the best rotation to relate two sets of vectors."
- [Kabsch 1978] Kabsch, W. (1978). "A discussion of the solution for the best rotation to relate two sets of vectors."
- [Horn 1987] Horn, B.K.P. (1987). "Closed-form solution of absolute orientation using unit quaternions."
- [Umeyama 1991] Umeyama, S. (1991). "Least-squares estimation of transformation parameters between two point patterns."