This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
mlx-manopt is an MLX-native Riemannian optimization library for Apple Silicon. It provides GPU-accelerated manifold optimization with a PyManopt-compatible API.
Status: v0.2.0, full PyManopt API parity. 313/313 tests passing and 60/60 cross-validation checks against pymanopt. All 15 real-valued pymanopt manifolds and 5 optimizers are implemented; complex-valued manifolds are blocked by MLX limitations.
- Manifold tests: 281 tests covering Sphere (+ SubspaceIntersection variants), Stiefel (QR + polar + multi-k), Grassmann (+ multi-k), Euclidean, Symmetric, SkewSymmetric, Oblique, SPD (+ multi-k), PSDFixedRank, Elliptope, Product, Positive, SpecialOrthogonalGroup (SO(n)), FixedRankEmbedded, PoincareBall
- Optimizer tests: 4 smoke tests for SteepestDescent, ConjugateGradient, and TrustRegions, plus NelderMead and ParticleSwarm tests
- Other tests: Diagnostics, line search, backend integration (21 tests)
- Cross-validation: `tests/cross_validate_pymanopt.py` (60/60 same-input comparisons vs pymanopt)
- Run tests:

```bash
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/ -v
```
ALWAYS use the `manopt` conda env. The base conda env has a broken MLX Metal backend.

```bash
# Run any command in this project with:
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python ...
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/ -v
```

Do NOT use the base anaconda python: it will fail with `Failed to load the default metallib`.
```bash
# Install in development mode
pip install -e .

# Install with dev dependencies
pip install -e ".[dev]"

# Run tests (use manopt conda env)
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/

# Linting and formatting
ruff check src/
ruff format src/
black src/

# Type checking
mypy src/

# Verify imports work
python -c "from mlx_manopt import Sphere, SteepestDescent; print('OK')"
```

```text
src/
├── __init__.py               # Main package exports
├── manifolds/
│   ├── manifold.py           # Manifold, RiemannianSubmanifold base classes
│   ├── sphere.py             # Unit sphere S^{n-1}, SubspaceIntersection variants
│   ├── stiefel.py            # Orthonormal matrices St(n,p)
│   ├── grassmann.py          # Subspaces Gr(n,p)
│   ├── euclidean.py          # Unconstrained R^n, Symmetric, SkewSymmetric
│   ├── oblique.py            # Oblique manifold (unit-norm columns)
│   ├── positive_definite.py  # Symmetric positive definite
│   ├── psd.py                # PSDFixedRank, Elliptope
│   ├── positive.py           # Positive matrices P(m,n)
│   ├── hyperbolic.py         # Poincaré ball B(n)
│   ├── group.py              # SpecialOrthogonalGroup SO(n)
│   ├── fixed_rank.py         # Fixed-rank matrices (SVD param)
│   └── product.py            # Product manifolds
├── optimizers/
│   ├── optimizer.py          # Optimizer, OptimizerResult
│   ├── steepest_descent.py
│   ├── conjugate_gradient.py
│   ├── trust_regions.py
│   ├── nelder_mead.py        # Derivative-free simplex
│   ├── particle_swarm.py     # Derivative-free population
│   └── line_search.py        # BackTrackingLineSearcher, AdaptiveLineSearcher
├── core/
│   └── problem.py            # Problem definition
├── tools/
│   ├── linalg.py             # Matrix exp/log, batched operations
│   ├── multi.py              # Vectorized matrix ops (multitransp, etc.)
│   ├── printer.py            # VoidPrinter, ColumnPrinter
│   ├── diagnostics.py        # check_gradient, check_hessian, etc.
│   └── testing.py            # Gradient/Hessian computation utilities
└── backends/
    └── mlx_backend.py        # PyManopt Backend ABC implementation
```
Every manifold implements:
- Required: `inner_product`, `norm`, `projection`, `random_point`, `random_tangent_vector`, `zero_vector`
- Optional: `dist`, `exp`, `log`, `retraction`, `transport`
- Gradient conversion: `euclidean_to_riemannian_gradient`
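To make that interface concrete, here is a minimal sketch of a unit-sphere manifold. It is written against plain Python lists so it runs without MLX. The class name `ToySphere` is hypothetical. The method signatures are assumed to follow PyManopt's `(point, ...)` argument order. The real `Sphere` in `manifolds/sphere.py` works on `mx.array` and may differ in detail.

```python
import math
import random


class ToySphere:
    """Hypothetical unit sphere S^{n-1} in R^n on plain Python lists.

    Illustrative only: mirrors the method names listed above, not the
    actual mlx_manopt Sphere implementation.
    """

    def __init__(self, n):
        self.n = n

    # --- required methods ---
    def inner_product(self, point, u, v):
        # The sphere inherits the Euclidean metric, so `point` is unused.
        return sum(a * b for a, b in zip(u, v))

    def norm(self, point, u):
        return math.sqrt(self.inner_product(point, u, u))

    def projection(self, point, vector):
        # Remove the component along `point`: P_x(v) = v - <x, v> x
        c = sum(a * b for a, b in zip(point, vector))
        return [v - c * x for x, v in zip(point, vector)]

    def random_point(self):
        v = [random.gauss(0.0, 1.0) for _ in range(self.n)]
        r = math.sqrt(sum(a * a for a in v))
        return [a / r for a in v]

    def random_tangent_vector(self, point):
        # Project a Gaussian vector onto the tangent space, then normalize.
        v = self.projection(point, [random.gauss(0.0, 1.0) for _ in range(self.n)])
        r = self.norm(point, v)
        return [a / r for a in v]

    def zero_vector(self, point):
        return [0.0] * self.n

    # --- optional method ---
    def retraction(self, point, tangent_vector):
        # Metric-projection retraction: step in R^n, then renormalize.
        y = [x + t for x, t in zip(point, tangent_vector)]
        r = math.sqrt(sum(a * a for a in y))
        return [a / r for a in y]

    # --- gradient conversion ---
    def euclidean_to_riemannian_gradient(self, point, euclidean_gradient):
        # With the induced metric, the Riemannian gradient is the
        # tangent-space projection of the Euclidean gradient.
        return self.projection(point, euclidean_gradient)
```

The sphere carries the metric inherited from R^n. Because of that, `euclidean_to_riemannian_gradient` reduces to `projection`. Manifolds with a non-Euclidean metric need an extra metric-dependent step.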
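```python
import mlx.core as mx
from mlx_manopt import Problem, Sphere, SteepestDescent

# 1. Define manifold constraint
manifold = Sphere(100)

# 2. Define cost function (gradient computed via mx.grad automatically).
# Rayleigh quotient: minimizing -x^T A x over the sphere yields the
# leading eigenvector of the symmetric matrix A.
A = mx.random.normal((100, 100))
A = 0.5 * (A + A.T)

def cost(x):
    return -mx.sum(x * (A @ x))

# 3. Create problem
problem = Problem(manifold, cost)

# 4. Optimize
result = SteepestDescent().run(problem)
```

- All operations use `mx.array`, not NumPy
- Use `mx.grad()` for automatic differentiation
- Use `mx.vmap()` for batched operations
- Call `mx.eval()` only at optimization boundaries (lazy evaluation)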
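`SteepestDescent` with a backtracking line search hides the iteration loop. The plain-Python sketch below is a rough mental model of that loop, not the library's implementation. It runs Riemannian steepest descent on the unit sphere in three steps:

- project the Euclidean gradient onto the tangent space;
- shrink the step until an Armijo sufficient-decrease condition holds;
- retract by renormalizing.

The function name `steepest_descent_sphere` is hypothetical. The constants (contraction 0.5, sufficient decrease 1e-4, at most 50 halvings) are illustrative assumptions and are not taken from `BackTrackingLineSearcher`.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(A, x):
    return [dot(row, x) for row in A]


def retract(x, v, t):
    # Renormalizing retraction: R_x(t v) = (x + t v) / ||x + t v||
    y = [a + t * b for a, b in zip(x, v)]
    r = math.sqrt(dot(y, y))
    return [a / r for a in y]


def steepest_descent_sphere(cost, egrad, x0, max_iter=500, gtol=1e-8,
                            contraction=0.5, sufficient_decrease=1e-4):
    """Hypothetical sketch of Riemannian steepest descent on S^{n-1}."""
    x = list(x0)
    for _ in range(max_iter):
        g = egrad(x)
        c = dot(x, g)
        grad = [gi - c * xi for gi, xi in zip(g, x)]  # P_x(Euclidean gradient)
        gnorm2 = dot(grad, grad)
        if math.sqrt(gnorm2) < gtol:
            break
        fx = cost(x)
        step = 1.0
        # Backtracking: shrink the step until the Armijo condition holds;
        # if all halvings fail, the last (tiny) step is accepted.
        for _ in range(50):
            y = retract(x, grad, -step)
            if cost(y) <= fx - sufficient_decrease * step * gnorm2:
                break
            step *= contraction
        x = y
    return x


# Usage: leading eigenvector of a small symmetric matrix via the
# Rayleigh-quotient cost f(x) = -x^T A x, whose Euclidean gradient is -2 A x.
A = [[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]]
x = steepest_descent_sphere(
    cost=lambda p: -dot(p, matvec(A, p)),
    egrad=lambda p: [-2.0 * v for v in matvec(A, p)],
    x0=[1.0 / math.sqrt(3.0)] * 3,
)
rho = dot(x, matvec(A, x))  # approaches the largest eigenvalue, 3 + sqrt(3)
```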
All core components complete:
- ✅ Phase 1: Sphere + SteepestDescent + BackTrackingLineSearcher
- ✅ Phase 2: Stiefel, Grassmann, Euclidean, Symmetric, Product
- ✅ Phase 3: ConjugateGradient, TrustRegions
- ✅ Phase 4: SymmetricPositiveDefinite (matrix_exp via Padé, matrix_log via eigendecomposition)
- ✅ Phase 5: Positive, SpecialOrthogonalGroup (SO(n)), FixedRankEmbedded + utility additions (multiskew, matrix_log_general)
- ✅ Phase 6: PoincareBall — Poincaré ball model of hyperbolic space (pure Python port, conformal metric, Möbius addition, exp/log/dist)
- ✅ Phase 7: Gap analysis — SkewSymmetric, Oblique, PSDFixedRank, Elliptope + multi-k variants for Stiefel/Grassmann/SPD + solve_continuous_lyapunov utility
- ✅ Phase 8: PyManopt structural conformance — tools/, multi.py, printer.py, diagnostics.py, testing.py, NelderMead, ParticleSwarm, SphereSubspaceIntersection, SphereSubspaceComplementIntersection
- ✅ Phase 9: API parity — all constructor param names, file names, class names match pymanopt exactly. Problem class accepts riemannian_gradient/riemannian_hessian/preconditioner.
- Unit test each manifold method individually
- Verify constraint satisfaction (e.g., ||x|| = 1 for Sphere)
- Check gradient conversion via finite differences
- Integration tests: full optimization with known solutions
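The finite-difference item can be checked in a simple way. Estimate the derivative of t ↦ f(R_x(t v)) at t = 0 with a central difference, then compare it with ⟨grad f(x), v⟩. The sketch below does this on the unit sphere in plain Python. The helper names are hypothetical, and the `check_gradient` in `tools/diagnostics.py` may use a different test.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(u):
    r = math.sqrt(dot(u, u))
    return [a / r for a in u]


def project(x, u):
    # Tangent-space projection on the unit sphere: P_x(u) = u - <x, u> x
    c = dot(x, u)
    return [a - c * b for a, b in zip(u, x)]


def retract(x, v, t):
    # Renormalizing retraction R_x(t v)
    return normalize([a + t * b for a, b in zip(x, v)])


def fd_gradient_error(cost, riemannian_grad, x, v, h=1e-5):
    """|central difference of t -> f(R_x(t v)) at 0  -  <grad f(x), v>|."""
    fd = (cost(retract(x, v, h)) - cost(retract(x, v, -h))) / (2.0 * h)
    return abs(fd - dot(riemannian_grad(x), v))


# Usage: f(x) = sum_i c_i x_i^3 is non-constant on the sphere.
coeffs = [1.0, -2.0, 0.5, 3.0]


def cost(x):
    return sum(c * xi**3 for c, xi in zip(coeffs, x))


def egrad(x):
    return [3.0 * c * xi**2 for c, xi in zip(coeffs, x)]


def rgrad(x):
    return project(x, egrad(x))


x = normalize([0.3, -0.1, 0.8, 0.5])
v = project(x, [0.2, 0.4, -0.1, 0.3])
err_ok = fd_gradient_error(cost, rgrad, x, v)
# A deliberately wrong gradient (scaled by 2) should fail the check.
err_bad = fd_gradient_error(cost, lambda p: [2.0 * g for g in rgrad(p)], x, v)
```

Because `v` is tangent, ⟨P_x(g), v⟩ = ⟨g, v⟩. This directional test therefore cannot detect a gradient that was left unprojected. A separate assertion that ⟨x, grad f(x)⟩ ≈ 0 covers that case.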
- mlx-hyperbolic: Hyperbolic geometry in MLX (sister project)
- PyManopt: Reference implementation (NumPy-based)
- Geoopt: PyTorch manifold optimization
- Do NOT make unilateral decisions - especially about test configs, architecture, or API changes
- Proceed step-by-step when making major decisions and get approval first
- When faced with a decision, ASK for permission - do not assume
- Do NOT modify tests without explicit approval
- Apples-to-apples comparison with PyManopt is the goal - tests should mirror pymanopt exactly
- pymanopt-reference/ contains the reference pymanopt repo for comparison
- If pymanopt tests fail due to environment issues (e.g., TensorFlow), report and ask for guidance
- Match PyManopt API exactly - no gratuitous differences
- All API matching is documented in `API_MATCH.md`
- Any deviation from PyManopt API requires explicit approval and documentation