# mlx-manopt: Riemannian Optimization on Apple Silicon

MLX-native port of PyManopt for Apple Silicon GPUs. Full API parity — all 15 real-valued manifolds, 5 optimizers, and identical class/parameter names.

```
+------------------+       +------------------+       +------------------+
|   Cost Function  | ----> |     Problem      | ----> |    Optimizer     |
|   f: M -> R      |       | manifold + cost  |       |   SD / CG / TR   |
+------------------+       +------------------+       +------------------+
        |                          |                          |
        v                          v                          v
  mx.grad (autodiff)        Manifold constraint        Riemannian updates
                            (projection, retraction)   (geodesic steps)
```

## Why Riemannian Optimization?

Many ML problems have natural geometric constraints:

| Problem | Constraint | Manifold |
| --- | --- | --- |
| Normalized embeddings | \|\|x\|\| = 1 | `Sphere` |
| PCA, subspace learning | Orthonormal basis | `Stiefel` |
| Low-rank matrix completion | Fixed rank | `FixedRankEmbedded` |
| Covariance estimation | Positive definite | `SPD` |
| Hyperbolic embeddings | Poincare ball | `PoincareBall` |
| Dictionary learning | Unit-norm atoms | `Oblique` |

Riemannian optimization respects these constraints by design, avoiding projection hacks and improving convergence.

## Installation

```shell
# Requires Apple Silicon Mac with macOS 13+
pip install mlx

# Clone and install
git clone https://github.com/nborwankar/mlx-manopt.git
cd mlx-manopt
pip install -e .
```

## Quick Start

```python
import mlx.core as mx
from mlx_manopt import Sphere, SteepestDescent, Problem

# Find the leading eigenvector of a matrix
n = 100
A = mx.random.normal((n, n))
A = (A + A.T) / 2  # Symmetric matrix

# Minimize -x^T A x on the unit sphere
def cost(x):
    return -mx.sum(x * (A @ x))

problem = Problem(Sphere(n), cost)
result = SteepestDescent().run(problem)

print(f"Converged in {result.iterations} iterations")
print(f"Final cost: {result.cost:.6f}")
```
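For intuition about what the optimizer is doing, the same problem can be solved with a hand-rolled Riemannian steepest descent in plain NumPy (a minimal sketch of the geometry, not mlx-manopt's implementation): project the Euclidean gradient onto the sphere's tangent space, take a step, and retract by renormalizing.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
A = rng.normal(size=(n, n))
A = (A + A.T) / 2  # symmetric matrix

x = rng.normal(size=n)
x /= np.linalg.norm(x)  # random point on the unit sphere

step = 0.01
for _ in range(5000):
    egrad = -2 * A @ x                # Euclidean gradient of -x^T A x
    rgrad = egrad - (x @ egrad) * x   # project onto the tangent space at x
    x = x - step * rgrad              # gradient step
    x /= np.linalg.norm(x)            # retraction: back onto the sphere

# x should align with the leading eigenvector of A (up to sign)
lead = np.linalg.eigh(A)[1][:, -1]
print(abs(x @ lead))  # close to 1.0
```

The projection and renormalization steps are exactly the "manifold constraint" and "Riemannian updates" boxes in the diagram above.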

## Available Manifolds

All 15 real-valued manifolds from PyManopt, with identical class names and constructor parameters:

| Manifold | Description | Dimension |
| --- | --- | --- |
| `Sphere(n)` | Unit sphere S^{n-1} in R^n | n - 1 |
| `Stiefel(n, p)` | Orthonormal n x p matrices | np - p(p+1)/2 |
| `Grassmann(n, p)` | p-dimensional subspaces of R^n | p(n-p) |
| `Euclidean(*shape)` | Unconstrained R^n or R^{m x n} | product of dims |
| `Symmetric(n)` | n x n symmetric matrices | n(n+1)/2 |
| `SkewSymmetric(n)` | n x n skew-symmetric matrices | n(n-1)/2 |
| `Oblique(m, n)` | m x n matrices with unit-norm columns | (m-1) x n |
| `SymmetricPositiveDefinite(n)` | n x n SPD matrices | n(n+1)/2 |
| `PSDFixedRank(n, k)` | Rank-k PSD matrices via Y Y^T | kn - k(k-1)/2 |
| `Elliptope(n, k)` | Correlation matrices (PSD + diag = 1) | n(k-1) - k(k-1)/2 |
| `Positive(m, n)` | Matrices with strictly positive entries | m x n |
| `PoincareBall(n)` | Poincare ball model of hyperbolic space | n |
| `SpecialOrthogonalGroup(n)` | Rotation matrices SO(n) | n(n-1)/2 |
| `FixedRankEmbedded(m, n, k)` | m x n matrices of fixed rank k | (m+n-k) x k |
| `Product([M1, M2, ...])` | Cartesian product | sum of dims |

Variants: SphereSubspaceIntersection, SphereSubspaceComplementIntersection, multi-k versions of Stiefel/Grassmann/SPD.
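The dimension column follows the standard counting argument: ambient degrees of freedom minus independent constraints. For instance, `Stiefel(n, p)` has np matrix entries and p(p+1)/2 orthonormality constraints, and `Grassmann(n, p)` further quotients out the p(p-1)/2 rotations of the basis. A quick arithmetic check (helper names here are illustrative, not part of the library):

```python
def stiefel_dim(n, p):
    # n*p entries minus p(p+1)/2 orthonormality constraints (X^T X = I_p)
    return n * p - p * (p + 1) // 2

def grassmann_dim(n, p):
    # Stiefel dimension minus dim SO(p)-like basis rotations: equals p*(n-p)
    return stiefel_dim(n, p) - p * (p - 1) // 2

print(stiefel_dim(5, 2))    # 7  (= 10 - 3)
print(grassmann_dim(5, 2))  # 6  (= 2 * (5 - 2))
```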

## Available Optimizers

| Optimizer | Type | Best For |
| --- | --- | --- |
| `SteepestDescent` | 1st order, gradient | Simple problems, debugging |
| `ConjugateGradient` | 1st order, gradient | Large-scale, ill-conditioned problems |
| `TrustRegions` | 2nd order, Hessian | High-accuracy solutions |
| `NelderMead` | Derivative-free, simplex | Non-smooth or noisy objectives |
| `ParticleSwarm` | Derivative-free, population | Global optimization, multimodal |

## PyManopt Compatibility

mlx-manopt matches PyManopt's public API exactly:

- **Class names:** `Sphere`, `Stiefel`, `ConjugateGradient`, etc. — identical
- **Constructor parameters:** `contraction_factor`, `beta_rule`, `rho_prime`, etc. — identical
- **File layout:** `manifolds/`, `optimizers/`, `tools/`, `core/` — same structure
- **Problem signature:** `Problem(manifold, cost, *, euclidean_gradient=None, riemannian_gradient=None, ...)`
- **Backend:** registers as a PyManopt backend for interop (`src/backends/mlx_backend.py`)

Existing PyManopt code can be ported by changing imports:

```python
# Before (PyManopt + NumPy)
from pymanopt.manifolds import Sphere
from pymanopt.optimizers import SteepestDescent

# After (mlx-manopt + MLX)
from mlx_manopt import Sphere, SteepestDescent
```

Cross-validated: 60/60 same-input comparisons produce matching results (within float32 tolerance).

## Tools and Diagnostics

```python
from mlx_manopt.tools.diagnostics import check_gradient, check_hessian, check_retraction

# Verify your cost function's gradient is correct
check_gradient(problem)

# Verify the Hessian
check_hessian(problem)
```
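Gradient checkers of this kind typically compare the directional derivative implied by your gradient against a finite-difference approximation of the cost. A NumPy sketch of the underlying idea (not mlx-manopt's implementation):

```python
import numpy as np

def finite_difference_check(cost, grad, x, v, eps=1e-6):
    """Compare <grad(x), v> with a central finite difference of cost along v."""
    directional = grad(x) @ v
    fd = (cost(x + eps * v) - cost(x - eps * v)) / (2 * eps)
    return abs(directional - fd)

# Example: f(x) = ||x||^2, whose gradient is 2x
cost = lambda x: np.sum(x ** 2)
grad = lambda x: 2 * x

rng = np.random.default_rng(1)
x = rng.normal(size=10)
v = rng.normal(size=10)
print(finite_difference_check(cost, grad, x, v))  # tiny: gradient is correct
```

If the gradient were wrong, the discrepancy would stay O(1) instead of shrinking with `eps`.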

## API Overview

### Problem Definition

```python
from mlx_manopt import Problem, Sphere
import mlx.core as mx

manifold = Sphere(100)

def cost(x):
    return mx.sum(x ** 2)

# Gradient computed automatically via mx.grad
problem = Problem(manifold, cost)

# Or provide your own Riemannian gradient directly
problem = Problem(manifold, cost, riemannian_gradient=my_grad_fn)
```

### Optimization

```python
from mlx_manopt import SteepestDescent, ConjugateGradient, TrustRegions

# Configure optimizer
optimizer = ConjugateGradient(
    max_iterations=1000,
    min_gradient_norm=1e-6,
    beta_rule="FletcherReeves",  # or HagerZhang, HestenesStiefel, PolakRibiere, LiuStorey
    verbosity=2,
)

# Run optimization
result = optimizer.run(problem, initial_point=x0)

# Access results
optimal_x = result.point
final_cost = result.cost
num_iters = result.iterations
```
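The `beta_rule` parameter selects how the previous search direction is blended into the new one. As a reference for what these rules compute, here is the Fletcher-Reeves update in plain NumPy on a Euclidean quadratic (a sketch only; the Riemannian version additionally transports the previous direction between tangent spaces):

```python
import numpy as np

def cg_fletcher_reeves(Q, b, x, iters=50):
    """Minimize 0.5 x^T Q x - b^T x with nonlinear CG, Fletcher-Reeves beta."""
    g = Q @ x - b          # gradient
    d = -g                 # initial direction: steepest descent
    for _ in range(iters):
        alpha = -(g @ d) / (d @ Q @ d)    # exact line search for a quadratic
        x = x + alpha * d
        g_new = Q @ x - b
        if np.linalg.norm(g_new) < 1e-12:
            return x
        beta = (g_new @ g_new) / (g @ g)  # Fletcher-Reeves: ||g_new||^2 / ||g||^2
        d = -g_new + beta * d             # new conjugate direction
        g = g_new
    return x

rng = np.random.default_rng(2)
M = rng.normal(size=(8, 8))
Q = M @ M.T + 8 * np.eye(8)   # symmetric positive definite
b = rng.normal(size=8)
x = cg_fletcher_reeves(Q, b, np.zeros(8))
print(np.linalg.norm(Q @ x - b))  # residual near zero
```

The other rules (`PolakRibiere`, `HestenesStiefel`, etc.) differ only in the formula for `beta`.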

### Custom Manifolds

```python
from mlx_manopt.manifolds import Manifold
import mlx.core as mx

class MyManifold(Manifold):
    def __init__(self, n):
        super().__init__(name="My Manifold", dimension=n)
        self._n = n

    def inner_product(self, point, tangent_vector_a, tangent_vector_b):
        return mx.sum(tangent_vector_a * tangent_vector_b)

    def projection(self, point, vector):
        return vector - mx.sum(vector * point) * point

    def retraction(self, point, tangent_vector):
        y = point + tangent_vector
        return y / mx.linalg.norm(y)

    def random_point(self):
        x = mx.random.normal((self._n,))
        return x / mx.linalg.norm(x)

    def random_tangent_vector(self, point):
        v = mx.random.normal(point.shape)
        v = self.projection(point, v)
        return v / mx.linalg.norm(v)

    def zero_vector(self, point):
        return mx.zeros_like(point)

    def zero_vector(self, point):
        return mx.zeros_like(point)
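The invariants this `Manifold` contract relies on can be checked numerically. In NumPy, using the same sphere-like formulas as the class above: projected vectors must be tangent (orthogonal to the point), and the retraction must land back on the manifold.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=10)
x /= np.linalg.norm(x)           # a point on the unit sphere

v = rng.normal(size=10)
v_t = v - (v @ x) * x            # projection onto the tangent space at x

y = x + 0.1 * v_t
y /= np.linalg.norm(y)           # retraction: back onto the sphere

print(abs(v_t @ x))                   # ~0: tangent vector is orthogonal to x
print(abs(np.linalg.norm(y) - 1.0))   # ~0: retracted point is on the manifold
```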

## Project Status

Current: v0.2.0 — Full PyManopt API parity

| Component | Status | Tests |
| --- | --- | --- |
| Manifolds (15 + variants) | Complete | 281 |
| Optimizers (5) | Complete | 4 smoke + 2 derivative-free |
| Line search (2) | Complete | 3 |
| Tools (diagnostics, testing) | Complete | 6 |
| Backend (PyManopt interop) | Complete | 14 |
| Cross-validation vs pymanopt | 60/60 | 3 |
| **Total** | | **313** |

Remaining gaps: complex-valued manifolds (7 classes) blocked by MLX lacking complex number support.

## Requirements

- Apple Silicon Mac (M1/M2/M3/M4)
- macOS 13.0+ (Ventura or later)
- Python 3.11+
- MLX 0.0.10+

## License

MIT License — see LICENSE for details.

## Author

Nitin Borwankar (@nborwankar)
