2
2
# using Revise
3
3
using Test
4
4
using LinearAlgebra
5
+ using ManifoldsBase
5
6
using Manifolds, Manopt
7
+ using Optim
6
8
using FiniteDifferences, ManifoldDiff
7
9
8
10
# #
44
46
# #
45
47
end
46
48
49
+
50
+ @testset " Optim.jl ManifoldWrapper example from mateuszbaran (copied to catch issues on future changes)" begin
51
+ # Example modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
52
+ # #
53
+
54
+ """
55
+ ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
56
+
57
+ Adapts Manifolds.jl manifolds for use in Optim.jl
58
+ """
59
+ struct ManifoldWrapper{TM<: AbstractManifold } <: Optim.Manifold
60
+ M:: TM
61
+ end
62
+
63
+ function Optim. retract! (M:: ManifoldWrapper , x)
64
+ ManifoldsBase. embed_project! (M. M, x, x)
65
+ return x
66
+ end
67
+
68
+ function Optim. project_tangent! (M:: ManifoldWrapper , g, x)
69
+ ManifoldsBase. embed_project! (M. M, g, x, g)
70
+ return g
71
+ end
72
+
73
+ # example usage of Manifolds.jl manifolds in Optim.jl
74
+ M = Manifolds. Sphere (2 )
75
+ x0 = [1.0 , 0.0 , 0.0 ]
76
+ q = [0.0 , 1.0 , 0.0 ]
77
+
78
+ f (p) = 0.5 * distance (M, p, q)^ 2
79
+
80
+ # manual gradient
81
+ function g! (X, p)
82
+ log! (M, X, p, q)
83
+ X .*= - 1
84
+ println (p, X)
85
+ end
86
+
87
+ # #
88
+
89
+ sol = optimize (f, g!, x0, ConjugateGradient (; manifold= ManifoldWrapper (M)))
90
+ @test isapprox ([0 ,1 ,0. ], sol. minimizer; atol= 1e-8 )
91
+
92
+
93
+ # # finitediff gradient (non-manual)
94
+
95
+ r_backend = ManifoldDiff. TangentDiffBackend (
96
+ ManifoldDiff. FiniteDifferencesBackend ()
97
+ )
98
+ function g_FD! (X,p)
99
+ X .= ManifoldDiff. gradient (M, f, p, r_backend)
100
+ X
101
+ end
102
+
103
+ #
104
+ x0 = [1.0 , 0.0 , 0.0 ]
105
+
106
+ sol = optimize (f, g_FD!, x0, ConjugateGradient (; manifold= ManifoldWrapper (M)))
107
+ @test isapprox ([0 ,1 ,0. ], sol. minimizer; atol= 1e-8 )
108
+
109
+ # #
110
+
111
+ # x0 = [1.0, 0.0, 0.0]
112
+ # # internal ForwardDfif doesnt work out the box on Manifolds
113
+ # sol = optimize(f, x0, ConjugateGradient(; manifold=ManifoldWrapper(M)); autodiff=:forward )
114
+ # @test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
115
+
116
+ # #
117
+ end
118
+
47
119
# #
0 commit comments