@@ -20,9 +20,9 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020retract (M:: Manifold , x) = retract! (M, copy (x))
2121
2222# Fake objective function implementing a retraction
23- mutable struct ManifoldObjective{T<: NLSolversBase.AbstractObjective } < :
23+ mutable struct ManifoldObjective{M <: Manifold , T<: NLSolversBase.AbstractObjective } < :
2424 NLSolversBase. AbstractObjective
25- manifold:: Manifold
25+ manifold:: M
2626 inner_obj:: T
2727end
2828# TODO : is it safe here to call retract! and change x?
@@ -52,6 +52,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
5252 return value (obj. inner_obj)
5353end
5454
55+ # In general, we have to compute the gradient/Jacobian separately as it has to be projected
56+ function NLSolversBase. jvp! (obj:: ManifoldObjective , x, v)
57+ xin = retract (obj. manifold, x)
58+ gradient! (obj. inner_obj, xin)
59+ project_tangent! (obj. manifold, gradient (obj. inner_obj), xin)
60+ return dot (gradient (obj. inner_obj), v)
61+ end
62+ function NLSolversBase. value_jvp! (obj:: ManifoldObjective , x, v)
63+ xin = retract (obj. manifold, x)
64+ value_gradient! (obj. inner_obj, xin)
65+ project_tangent! (obj. manifold, gradient (obj. inner_obj), xin)
66+ return value (obj. inner_obj), dot (gradient (obj. inner_obj), v)
67+ end
68+
5569""" Flat Euclidean space {R,C}^N, with projections equal to the identity."""
5670struct Flat <: Manifold end
5771# all the functions below are no-ops, and therefore the generated code
@@ -62,6 +76,10 @@ retract!(M::Flat, x) = x
6276project_tangent (M:: Flat , g, x) = g
6377project_tangent! (M:: Flat , g, x) = g
6478
79+ # Optimizations for `Flat` manifold
80+ NLSolversBase. jvp! (obj:: ManifoldObjective{Flat} , x, v) = jvp! (obj. inner_obj, x, v)
81+ NLSolversBase. value_jvp! (obj:: ManifoldObjective{Flat} , x, v) = value_jvp! (obj. inner_obj, x, v)
82+
6583""" Spherical manifold {|x| = 1}."""
6684struct Sphere <: Manifold end
6785retract! (S:: Sphere , x) = (x ./= norm (x))
0 commit comments