CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

mlx-manopt is an MLX-native Riemannian optimization library for Apple Silicon. It provides GPU-accelerated manifold optimization with a PyManopt-compatible API.

Status: v0.2.0, full PyManopt API parity. 313/313 tests passing, 60/60 cross-validation checks vs pymanopt. All 15 real-valued pymanopt manifolds and all 5 optimizers are implemented. Complex-valued manifolds are blocked by MLX limitations.

Test Suite Summary

  • Manifold tests: 281 tests covering Sphere (+ SubspaceIntersection variants), Stiefel (QR + polar + multi-k), Grassmann (+ multi-k), Euclidean, Symmetric, SkewSymmetric, Oblique, SPD (+ multi-k), PSDFixedRank, Elliptope, Product, Positive, SpecialOrthogonalGroup (SO(n)), FixedRankEmbedded, PoincareBall
  • Optimizer tests: 4 smoke tests covering SteepestDescent, ConjugateGradient, and TrustRegions, plus tests for NelderMead and ParticleSwarm
  • Other tests: Diagnostics, line search, backend integration (21 tests)
  • Cross-validation: tests/cross_validate_pymanopt.py — 60/60 same-input comparisons vs pymanopt
  • Run tests: PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/ -v

Conda Environment

ALWAYS use the manopt conda env. The base conda env has a broken MLX Metal backend.

```bash
# Run any command in this project with:
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python ...
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/ -v
```

Do NOT use the base anaconda python; it fails with "Failed to load the default metallib".
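A script can also fail fast when it is launched from the wrong interpreter. The helper below is a hypothetical sketch (in_manopt_env is not part of the repository). It assumes conda's default layout, where the env's interpreter lives under .../envs/manopt/bin/, matching the PATH used above.

```python
from pathlib import PurePosixPath


def in_manopt_env(executable: str) -> bool:
    """Return True if `executable` lives under an .../envs/manopt/... directory.

    Hypothetical helper: assumes conda's default <root>/envs/<name>/bin/python layout.
    """
    parts = PurePosixPath(executable).parts
    # Look for the consecutive path components "envs", "manopt".
    return any(a == "envs" and b == "manopt" for a, b in zip(parts, parts[1:]))
```

For example, a local script could call in_manopt_env(sys.executable) before importing mlx and exit with a clear message when it returns False, instead of failing later with the metallib error.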

Build & Development Commands

```bash
# Commands below assume the manopt env's bin/ is first on PATH (see Conda Environment)

# Install in development mode
pip install -e .

# Install with dev dependencies
pip install -e ".[dev]"

# Run tests (use manopt conda env)
PATH=/Users/nitin/anaconda3/envs/manopt/bin:$PATH python -m pytest tests/

# Linting and formatting
ruff check src/
ruff format src/
black src/

# Type checking
mypy src/

# Verify imports work
python -c "from mlx_manopt import Sphere, SteepestDescent; print('OK')"
```

Architecture

```
src/
├── __init__.py           # Main package exports
├── manifolds/
│   ├── manifold.py       # Manifold, RiemannianSubmanifold base classes
│   ├── sphere.py         # Unit sphere S^{n-1}, SubspaceIntersection variants
│   ├── stiefel.py        # Orthonormal matrices St(n,p)
│   ├── grassmann.py      # Subspaces Gr(n,p)
│   ├── euclidean.py      # Unconstrained R^n, Symmetric, SkewSymmetric
│   ├── oblique.py        # Oblique manifold (unit-norm columns)
│   ├── positive_definite.py  # Symmetric positive definite
│   ├── psd.py            # PSDFixedRank, Elliptope
│   ├── positive.py       # Positive matrices P(m,n)
│   ├── hyperbolic.py     # Poincaré ball B(n)
│   ├── group.py          # SpecialOrthogonalGroup SO(n)
│   ├── fixed_rank.py     # Fixed-rank matrices (SVD param)
│   └── product.py        # Product manifolds
├── optimizers/
│   ├── optimizer.py      # Optimizer, OptimizerResult
│   ├── steepest_descent.py
│   ├── conjugate_gradient.py
│   ├── trust_regions.py
│   ├── nelder_mead.py    # Derivative-free simplex
│   ├── particle_swarm.py # Derivative-free population
│   └── line_search.py    # BackTrackingLineSearcher, AdaptiveLineSearcher
├── core/
│   └── problem.py        # Problem definition
├── tools/
│   ├── linalg.py         # Matrix exp/log, batched operations
│   ├── multi.py          # Vectorized matrix ops (multitransp, etc.)
│   ├── printer.py        # VoidPrinter, ColumnPrinter
│   ├── diagnostics.py    # check_gradient, check_hessian, etc.
│   └── testing.py        # Gradient/Hessian computation utilities
└── backends/
    └── mlx_backend.py    # PyManopt Backend ABC implementation
```
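The vectorized helpers in multi.py treat the last two axes as the matrix axes and broadcast over any leading batch axes, which is presumably what lets the multi-k manifold variants work on k matrices at once. Below is a NumPy sketch of three such helpers, assuming they broadly follow pymanopt's definitions; an MLX version would swap np.swapaxes for mx.swapaxes. The generalization to any number of leading axes is a choice made here, not a claim about the library.

```python
import numpy as np


def multitransp(A: np.ndarray) -> np.ndarray:
    """Transpose the last two axes; any leading axes are treated as a batch."""
    return np.swapaxes(A, -2, -1)


def multisym(A: np.ndarray) -> np.ndarray:
    """Symmetric part of each matrix in the batch."""
    return 0.5 * (A + multitransp(A))


def multiskew(A: np.ndarray) -> np.ndarray:
    """Skew-symmetric part of each matrix in the batch."""
    return 0.5 * (A - multitransp(A))


# A batch of k = 4 random 3x3 matrices.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3, 3))
S, K = multisym(A), multiskew(A)
```

Every square matrix splits uniquely into symmetric and skew-symmetric parts, so S + K recovers A exactly, batch by batch.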

Key Design Patterns

Manifold Interface

Every manifold implements:

  • Required: inner_product, norm, projection, random_point, random_tangent_vector, zero_vector
  • Optional: dist, exp, log, retraction, transport
  • Gradient conversion: euclidean_to_riemannian_gradient
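To make the interface concrete, here is a minimal sketch of a unit sphere that exposes the required methods plus retraction and euclidean_to_riemannian_gradient. It is written against NumPy so it runs anywhere, and the class name ToySphere is hypothetical. The library's own Sphere uses mx.array and inherits from the base classes in manifold.py; only the method names and argument order here follow pymanopt's conventions.

```python
import numpy as np


class ToySphere:
    """Unit sphere S^{n-1} in R^n (illustrative, NumPy-based)."""

    def __init__(self, n: int, seed: int = 0):
        self.n = n
        self._rng = np.random.default_rng(seed)

    # --- required ---
    def inner_product(self, point, tangent_vector_a, tangent_vector_b):
        # Metric induced from the ambient space R^n.
        return float(np.dot(tangent_vector_a, tangent_vector_b))

    def norm(self, point, tangent_vector):
        return float(np.linalg.norm(tangent_vector))

    def projection(self, point, vector):
        # Remove the component along the point: P_x(v) = v - <x, v> x.
        return vector - np.dot(point, vector) * point

    def random_point(self):
        x = self._rng.standard_normal(self.n)
        return x / np.linalg.norm(x)

    def random_tangent_vector(self, point):
        v = self.projection(point, self._rng.standard_normal(self.n))
        return v / np.linalg.norm(v)

    def zero_vector(self, point):
        return np.zeros(self.n)

    # --- optional ---
    def retraction(self, point, tangent_vector):
        # Step in the ambient space, then renormalize back onto the sphere.
        y = point + tangent_vector
        return y / np.linalg.norm(y)

    # --- gradient conversion ---
    def euclidean_to_riemannian_gradient(self, point, euclidean_gradient):
        # For an embedded submanifold with the induced metric, this is a projection.
        return self.projection(point, euclidean_gradient)
```

Manifolds with a non-induced metric (such as SPD or the Poincaré ball) need more than a projection in euclidean_to_riemannian_gradient, which is why it is listed separately from the required methods.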

Problem + Optimizer Pattern

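```python
import mlx.core as mx

# Assumes Problem is exported at the top level alongside Sphere and SteepestDescent
from mlx_manopt import Problem, Sphere, SteepestDescent

# 1. Define manifold constraint: the unit sphere in R^100
manifold = Sphere(100)

# 2. Define cost function (gradient computed via mx.grad automatically).
# Rayleigh quotient x^T A x; mx.sum(x ** 2) would be constant (= 1) on the sphere.
B = mx.random.normal(shape=(100, 100))
A = 0.5 * (B + B.T)

def cost(x):
    return mx.sum(x * (A @ x))

# 3. Create problem
problem = Problem(manifold, cost)

# 4. Optimize
result = SteepestDescent().run(problem)
```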
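Roughly, a Riemannian steepest-descent run repeats three steps: convert the Euclidean gradient into a Riemannian one, choose a step size by Armijo backtracking, and retract back onto the manifold. The NumPy sketch below does this for the Rayleigh quotient f(x) = x^T A x on the unit sphere. It is a simplified illustration with a hand-coded gradient, a fixed iteration budget, and a hypothetical function name (rayleigh_descent); it is not the library's SteepestDescent or BackTrackingLineSearcher.

```python
import numpy as np


def rayleigh_descent(A, x0, max_iter=1000, tol=1e-8, c=1e-4):
    """Minimize f(x) = x^T A x over the unit sphere (A symmetric)."""
    x = x0 / np.linalg.norm(x0)
    for _ in range(max_iter):
        fx = x @ A @ x
        egrad = 2.0 * (A @ x)                 # Euclidean gradient of x^T A x
        rgrad = egrad - (x @ egrad) * x       # project onto the tangent space at x
        gnorm2 = rgrad @ rgrad
        if np.sqrt(gnorm2) < tol:
            break
        # Armijo backtracking: halve the step until sufficient decrease,
        # evaluating the cost at the retracted point R_x(-alpha * rgrad).
        alpha = 1.0
        while True:
            y = x - alpha * rgrad
            y = y / np.linalg.norm(y)         # retraction: renormalize
            if y @ A @ y <= fx - c * alpha * gnorm2 or alpha < 1e-12:
                break
            alpha *= 0.5
        x = y
    return x


# Example: A has eigenvalues 1..10, so the minimum of f on the sphere is 1.
rng = np.random.default_rng(1)
Q, _ = np.linalg.qr(rng.standard_normal((10, 10)))
A = Q @ np.diag(np.arange(1.0, 11.0)) @ Q.T
x = rayleigh_descent(A, rng.standard_normal(10))
```

The minimizer of x^T A x on the sphere is an eigenvector for the smallest eigenvalue of A, which gives an easy known-solution check of the kind the integration tests use.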

MLX-Specific Considerations

  • All operations use mx.array, not NumPy
  • Use mx.grad() for automatic differentiation
  • Use mx.vmap() for batched operations
  • Call mx.eval() only at optimization boundaries (lazy evaluation)

Implementation Status

All core components complete:

  • Phase 1: Sphere + SteepestDescent + BackTrackingLineSearcher
  • Phase 2: Stiefel, Grassmann, Euclidean, Symmetric, Product
  • Phase 3: ConjugateGradient, TrustRegions
  • Phase 4: SymmetricPositiveDefinite (matrix_exp via Padé, matrix_log via eigendecomposition)
  • Phase 5: Positive, SpecialOrthogonalGroup (SO(n)), FixedRankEmbedded + utility additions (multiskew, matrix_log_general)
  • Phase 6: PoincareBall — Poincaré ball model of hyperbolic space (pure Python port, conformal metric, Möbius addition, exp/log/dist)
  • Phase 7: Gap analysis — SkewSymmetric, Oblique, PSDFixedRank, Elliptope + multi-k variants for Stiefel/Grassmann/SPD + solve_continuous_lyapunov utility
  • Phase 8: PyManopt structural conformance — tools/, multi.py, printer.py, diagnostics.py, testing.py, NelderMead, ParticleSwarm, SphereSubspaceIntersection, SphereSubspaceComplementIntersection
  • Phase 9: API parity — all constructor param names, file names, class names match pymanopt exactly. Problem class accepts riemannian_gradient/riemannian_hessian/preconditioner.
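Phase 4 notes that matrix_log for SPD matrices goes through an eigendecomposition. For S = V diag(λ) V^T with all λ > 0, log(S) = V diag(log λ) V^T, and the same trick gives the exponential of a symmetric matrix. The NumPy sketch below illustrates this; spd_log and sym_exp are hypothetical names. The library's matrix_exp uses a Padé approximant instead, which does not require a symmetric input, whereas this eigendecomposition shortcut does.

```python
import numpy as np


def spd_log(S: np.ndarray) -> np.ndarray:
    """Principal matrix logarithm of a symmetric positive definite matrix."""
    w, V = np.linalg.eigh(S)          # S = V diag(w) V^T with w > 0
    return (V * np.log(w)) @ V.T      # scale column j of V by log(w[j]), recombine


def sym_exp(X: np.ndarray) -> np.ndarray:
    """Matrix exponential of a symmetric matrix (the result is SPD)."""
    w, V = np.linalg.eigh(X)
    return (V * np.exp(w)) @ V.T


rng = np.random.default_rng(0)
B = rng.standard_normal((5, 5))
S = B @ B.T + 5 * np.eye(5)           # a well-conditioned SPD matrix
L = spd_log(S)
```

Two quick sanity checks: sym_exp(spd_log(S)) should return S, and trace(log S) should equal log det S.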

Testing Strategy

  • Unit test each manifold method individually
  • Verify constraint satisfaction (e.g., ||x|| = 1 for Sphere)
  • Check gradient conversion via finite differences
  • Integration tests: full optimization with known solutions
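A common finite-difference check for gradient conversion, and the idea behind check_gradient in diagnostics.py (pymanopt has a function of the same name), compares f(R_x(t v)) with its first-order model f(x) + t⟨grad f(x), v⟩. If the Riemannian gradient is correct, the error shrinks like t², giving a slope of about 2 on a log-log scale. The NumPy sketch below runs this check on the sphere; taylor_errors and the other helper names are hypothetical.

```python
import numpy as np


def taylor_errors(f, rgrad, retract, x, v, ts):
    """|f(R_x(t v)) - f(x) - t <grad f(x), v>| for each step size t in ts."""
    fx, dfx_v = f(x), rgrad(x) @ v
    return np.array([abs(f(retract(x, t * v)) - fx - t * dfx_v) for t in ts])


# Test problem: f(x) = x^T A x on the unit sphere in R^6.
rng = np.random.default_rng(0)
B = rng.standard_normal((6, 6))
A = 0.5 * (B + B.T)


def f(x):
    return x @ A @ x


def proj(x, u):
    return u - (x @ u) * x                    # tangent-space projection at x


def rgrad(x):
    return proj(x, 2.0 * (A @ x))             # Riemannian gradient


def retract(x, u):
    return (x + u) / np.linalg.norm(x + u)


x = rng.standard_normal(6)
x /= np.linalg.norm(x)
v = proj(x, rng.standard_normal(6))
v /= np.linalg.norm(v)

ts = np.logspace(-5, -2, 7)
errs = taylor_errors(f, rgrad, retract, x, v, ts)
# Fitted log-log slope; close to 2 when the gradient is correct.
slope = np.polyfit(np.log10(ts), np.log10(errs), 1)[0]
```

A wrong gradient leaves a residual first-order term, so the slope drops toward 1. The step range is kept moderate so that floating-point round-off does not dominate the smallest errors.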

Related Projects

  • mlx-hyperbolic: Hyperbolic geometry in MLX (sister project)
  • PyManopt: Reference implementation (NumPy-based)
  • Geoopt: PyTorch manifold optimization

Critical Guidelines for Claude

Decision-Making Protocol

  1. Do NOT make unilateral decisions - especially about test configs, architecture, or API changes
  2. Proceed step-by-step when making major decisions and get approval first
  3. When faced with a decision, ASK for permission - do not assume

Test Suite Rules

  1. Do NOT modify tests without explicit approval
  2. Apples-to-apples comparison with PyManopt is the goal - tests should mirror pymanopt exactly
  3. pymanopt-reference/ contains the reference pymanopt repo for comparison
  4. If pymanopt tests fail due to environment issues (e.g., TensorFlow), report and ask for guidance

API Compatibility

  1. Match PyManopt API exactly - no gratuitous differences
  2. All API matching is documented in API_MATCH.md
  3. Any deviation from PyManopt API requires explicit approval and documentation