Skip to content

Commit d7bd96e

Browse files
committed
add permanent check from Mateusz' example
1 parent fe893b0 commit d7bd96e

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

test/manifolds/manifolddiff.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# using Revise
33
using Test
44
using LinearAlgebra
5+
using ManifoldsBase
56
using Manifolds, Manopt
7+
using Optim
68
using FiniteDifferences, ManifoldDiff
79

810
##
@@ -44,4 +46,74 @@ end
4446
##
4547
end
4648

49+
50+
@testset "Optim.jl ManifoldWrapper example from mateuszbaran (copied to catch issues on future changes)" begin
51+
# Example modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
52+
##
53+
54+
"""
55+
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
56+
57+
Adapts Manifolds.jl manifolds for use in Optim.jl
58+
"""
59+
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
60+
M::TM
61+
end
62+
63+
function Optim.retract!(M::ManifoldWrapper, x)
64+
ManifoldsBase.embed_project!(M.M, x, x)
65+
return x
66+
end
67+
68+
function Optim.project_tangent!(M::ManifoldWrapper, g, x)
69+
ManifoldsBase.embed_project!(M.M, g, x, g)
70+
return g
71+
end
72+
73+
# example usage of Manifolds.jl manifolds in Optim.jl
74+
M = Manifolds.Sphere(2)
75+
x0 = [1.0, 0.0, 0.0]
76+
q = [0.0, 1.0, 0.0]
77+
78+
f(p) = 0.5 * distance(M, p, q)^2
79+
80+
# manual gradient
81+
function g!(X, p)
82+
log!(M, X, p, q)
83+
X .*= -1
84+
println(p, X)
85+
end
86+
87+
##
88+
89+
sol = optimize(f, g!, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)))
90+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
91+
92+
93+
## finitediff gradient (non-manual)
94+
95+
r_backend = ManifoldDiff.TangentDiffBackend(
96+
ManifoldDiff.FiniteDifferencesBackend()
97+
)
98+
function g_FD!(X,p)
99+
X .= ManifoldDiff.gradient(M, f, p, r_backend)
100+
X
101+
end
102+
103+
#
104+
x0 = [1.0, 0.0, 0.0]
105+
106+
sol = optimize(f, g_FD!, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)))
107+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
108+
109+
##
110+
111+
# x0 = [1.0, 0.0, 0.0]
112+
# # internal ForwardDfif doesnt work out the box on Manifolds
113+
# sol = optimize(f, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)); autodiff=:forward )
114+
# @test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
115+
116+
##
117+
end
118+
47119
##

0 commit comments

Comments
 (0)