MLX-native port of PyManopt for Apple Silicon GPUs. Full API parity — all 15 real-valued manifolds, 5 optimizers, and identical class/parameter names.
```
+------------------+       +------------------+       +------------------+
|  Cost Function   | ----> |     Problem      | ----> |    Optimizer     |
|    f: M -> R     |       | manifold + cost  |       |   SD / CG / TR   |
+------------------+       +------------------+       +------------------+
         |                          |                          |
         v                          v                          v
  mx.grad (autodiff)      Manifold constraint        Riemannian updates
                       (projection, retraction)      (geodesic steps)
```
Many ML problems have natural geometric constraints:
| Problem | Constraint | Manifold |
|---|---|---|
| Normalized embeddings | ‖x‖ = 1 | Sphere |
| PCA, subspace learning | Orthonormal basis | Stiefel |
| Low-rank matrix completion | Fixed rank | FixedRankEmbedded |
| Covariance estimation | Positive definite | SPD |
| Hyperbolic embeddings | Poincare ball | PoincareBall |
| Dictionary learning | Unit-norm atoms | Oblique |
Riemannian optimization respects these constraints by design, avoiding projection hacks and improving convergence.
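To make this concrete, here is a minimal, dependency-free sketch (plain Python, illustrative only, not mlx-manopt's implementation) of one Riemannian steepest-descent step on the unit sphere: the Euclidean gradient is projected onto the tangent space at the current point, and the update is pulled back onto the manifold by a retraction, so the constraint ‖x‖ = 1 holds exactly after every step.

```python
import math

# One Riemannian steepest-descent step on the unit sphere S^{n-1},
# minimizing the linear cost f(x) = <c, x>. Plain Python, no MLX.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def project(x, g):
    # Project a Euclidean gradient onto the tangent space at x:
    # remove the component along x itself.
    coeff = dot(g, x)
    return [gi - coeff * xi for gi, xi in zip(g, x)]

def retract(x, v):
    # Metric-projection retraction: step in the tangent direction,
    # then renormalize back onto the sphere.
    y = [xi + vi for xi, vi in zip(x, v)]
    n = norm(y)
    return [yi / n for yi in y]

c = [3.0, -1.0, 2.0]                 # cost f(x) = <c, x>
x = [1.0, 0.0, 0.0]                  # start on the sphere
g = project(x, c)                    # Riemannian gradient at x
step = 0.5
x_new = retract(x, [-step * gi for gi in g])

print(round(norm(x_new), 6))         # stays on the sphere: 1.0
print(dot(c, x_new) < dot(c, x))     # cost decreased: True
```

No external projection step is needed after the update: the retraction itself keeps the iterate feasible, which is the sense in which Riemannian methods respect the constraint "by design".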
```shell
# Requires Apple Silicon Mac with macOS 13+
pip install mlx

# Clone and install
git clone https://github.com/nborwankar/mlx-manopt.git
cd mlx-manopt
pip install -e .
```

```python
import mlx.core as mx
from mlx_manopt import Sphere, SteepestDescent, Problem

# Find the leading eigenvector of a matrix
n = 100
A = mx.random.normal((n, n))
A = (A + A.T) / 2  # Symmetric matrix

# Minimize -x^T A x on the unit sphere
def cost(x):
    return -mx.sum(x * (A @ x))

problem = Problem(Sphere(n), cost)
result = SteepestDescent().run(problem)

print(f"Converged in {result.iterations} iterations")
print(f"Final cost: {result.cost:.6f}")
```

All 15 real-valued manifolds from PyManopt, with identical class names and constructor parameters:
| Manifold | Description | Dimension |
|---|---|---|
| `Sphere(n)` | Unit sphere S^{n-1} in R^n | n-1 |
| `Stiefel(n, p)` | Orthonormal n x p matrices | np - p(p+1)/2 |
| `Grassmann(n, p)` | p-dimensional subspaces of R^n | p(n-p) |
| `Euclidean(*shape)` | Unconstrained R^n or R^{m x n} | product of dims |
| `Symmetric(n)` | n x n symmetric matrices | n(n+1)/2 |
| `SkewSymmetric(n)` | n x n skew-symmetric matrices | n(n-1)/2 |
| `Oblique(m, n)` | m x n matrices with unit-norm columns | (m-1) x n |
| `SymmetricPositiveDefinite(n)` | n x n SPD matrices | n(n+1)/2 |
| `PSDFixedRank(n, k)` | Rank-k PSD matrices via Y Y^T | kn - k(k-1)/2 |
| `Elliptope(n, k)` | Correlation matrices (PSD + diag=1) | n(k-1) - k(k-1)/2 |
| `Positive(m, n)` | Matrices with strictly positive entries | m x n |
| `PoincareBall(n)` | Poincare ball model of hyperbolic space | n |
| `SpecialOrthogonalGroup(n)` | Rotation matrices SO(n) | n(n-1)/2 |
| `FixedRankEmbedded(m, n, k)` | m x n matrices of fixed rank k | (m+n-k) x k |
| `Product([M1, M2, ...])` | Cartesian product | sum of dims |
Variants: SphereSubspaceIntersection, SphereSubspaceComplementIntersection, multi-k versions of Stiefel/Grassmann/SPD.
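The dimension column above is ordinary integer arithmetic; as a quick sanity check, the snippet below (plain Python, illustrative only, independent of any mlx-manopt API) evaluates a few of the formulas for concrete sizes.

```python
# Evaluate a few dimension formulas from the table for n=10, p=3, k=2, m=4.
n, p, k, m = 10, 3, 2, 4

dims = {
    "Sphere(10)":                  n - 1,                     # n - 1
    "Stiefel(10, 3)":              n * p - p * (p + 1) // 2,  # np - p(p+1)/2
    "Grassmann(10, 3)":            p * (n - p),               # p(n - p)
    "SymmetricPositiveDefinite(10)": n * (n + 1) // 2,        # n(n+1)/2
    "PSDFixedRank(10, 2)":         k * n - k * (k - 1) // 2,  # kn - k(k-1)/2
    "FixedRankEmbedded(4, 10, 2)": (m + n - k) * k,           # (m+n-k)k
}

for name, d in dims.items():
    print(name, d)
```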
| Optimizer | Type | Best For |
|---|---|---|
| `SteepestDescent` | 1st order, gradient | Simple problems, debugging |
| `ConjugateGradient` | 1st order, gradient | Large-scale, ill-conditioned problems |
| `TrustRegions` | 2nd order, Hessian | High-accuracy solutions |
| `NelderMead` | Derivative-free, simplex | Non-smooth or noisy objectives |
| `ParticleSwarm` | Derivative-free, population | Global optimization, multimodal |
mlx-manopt matches PyManopt's public API exactly:
- Class names: `Sphere`, `Stiefel`, `ConjugateGradient`, etc. — identical
- Constructor parameters: `contraction_factor`, `beta_rule`, `rho_prime`, etc. — identical
- File layout: `manifolds/`, `optimizers/`, `tools/`, `core/` — same structure
- Problem signature: `Problem(manifold, cost, *, euclidean_gradient=None, riemannian_gradient=None, ...)`
- Backend: registers as a PyManopt backend for interop (`src/backends/mlx_backend.py`)
Existing PyManopt code can be ported by changing imports:

```python
# Before (PyManopt + NumPy)
from pymanopt.manifolds import Sphere
from pymanopt.optimizers import SteepestDescent

# After (mlx-manopt + MLX)
from mlx_manopt import Sphere, SteepestDescent
```

Cross-validated: 60/60 same-input comparisons produce matching results (within float32 tolerance).
```python
from mlx_manopt.tools.diagnostics import check_gradient, check_hessian, check_retraction

# Verify your cost function's gradient is correct
check_gradient(problem)

# Verify Hessian
check_hessian(problem)
```

```python
import mlx.core as mx

from mlx_manopt import Problem, Sphere

manifold = Sphere(100)

def cost(x):
    return mx.sum(x ** 2)

# Gradient computed automatically via mx.grad
problem = Problem(manifold, cost)

# Or provide your own Riemannian gradient directly
problem = Problem(manifold, cost, riemannian_gradient=my_grad_fn)
```

```python
from mlx_manopt import SteepestDescent, ConjugateGradient, TrustRegions

# Configure optimizer
optimizer = ConjugateGradient(
    max_iterations=1000,
    min_gradient_norm=1e-6,
    beta_rule="FletcherReeves",  # or HagerZhang, HestenesStiefel, PolakRibiere, LiuStorey
    verbosity=2,
)

# Run optimization
result = optimizer.run(problem, initial_point=x0)

# Access results
optimal_x = result.point
final_cost = result.cost
num_iters = result.iterations
```

```python
import mlx.core as mx

from mlx_manopt.manifolds import Manifold

class MyManifold(Manifold):
    def __init__(self, n):
        super().__init__(name="My Manifold", dimension=n)
        self._n = n

    def inner_product(self, point, tangent_vector_a, tangent_vector_b):
        return mx.sum(tangent_vector_a * tangent_vector_b)

    def projection(self, point, vector):
        return vector - mx.sum(vector * point) * point

    def retraction(self, point, tangent_vector):
        y = point + tangent_vector
        return y / mx.linalg.norm(y)

    def random_point(self):
        x = mx.random.normal((self._n,))
        return x / mx.linalg.norm(x)

    def random_tangent_vector(self, point):
        v = mx.random.normal(point.shape)
        v = self.projection(point, v)
        return v / mx.linalg.norm(v)

    def zero_vector(self, point):
        return mx.zeros_like(point)
```

Current: v0.2.0 — Full PyManopt API parity
| Component | Status | Tests |
|---|---|---|
| Manifolds (15 + variants) | Complete | 281 |
| Optimizers (5) | Complete | 4 smoke + 2 derivative-free |
| Line search (2) | Complete | 3 |
| Tools (diagnostics, testing) | Complete | 6 |
| Backend (PyManopt interop) | Complete | 14 |
| Cross-validation vs pymanopt | 60/60 | 3 |
| Total | | 313 |
Remaining gaps: the 7 complex-valued manifold classes, blocked by MLX's lack of complex number support.
- Apple Silicon Mac (M1/M2/M3/M4)
- macOS 13.0+ (Ventura or later)
- Python 3.11+
- MLX 0.0.10+
- Absil, Mahony, Sepulchre. "Optimization Algorithms on Matrix Manifolds". Princeton University Press, 2008.
- Boumal. "An Introduction to Optimization on Smooth Manifolds". Cambridge University Press, 2023.
- MLX Documentation
- PyManopt — the reference implementation this library ports
MIT License — see LICENSE for details.
Nitin Borwankar (@nborwankar)