Skip to content

Commit c855c31

Browse files
authored
Merge pull request #1697 from JuliaRobotics/23Q1/test/manidiff01
basic manifold diff sandbox test
2 parents a14125f + 5791e0c commit c855c31

File tree

5 files changed

+291
-1
lines changed

5 files changed

+291
-1
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
20+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2021
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
2122
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
2223
KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
2324
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2425
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2526
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
2627
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
28+
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
2729
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
2830
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
2931
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -82,9 +84,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
8284
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
8385
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
8486
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
87+
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
8588
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8689
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
8790
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8891

8992
[targets]
90-
test = ["DifferentialEquations", "Flux", "Graphs", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
93+
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]

src/IncrementalInference.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ using Reexport
1515
using Manifolds
1616
using RecursiveArrayTools: ArrayPartition
1717
export ArrayPartition
18+
using ManifoldDiff
19+
using FiniteDifferences
1820

1921
using OrderedCollections: OrderedDict
2022

src/ManifoldsExtentions.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,50 @@
1+
2+
## ================================================================================================
3+
## Manifold and ManifoldDiff use with Optim
4+
## ================================================================================================
5+
6+
# Modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
7+
"""
8+
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
9+
10+
Adapts Manifolds.jl manifolds for use in Optim.jl
11+
"""
12+
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
13+
M::TM
14+
end
15+
16+
function Optim.retract!(M::ManifoldWrapper, x)
17+
ManifoldsBase.embed_project!(M.M, x, x)
18+
return x
19+
end
20+
21+
function Optim.project_tangent!(M::ManifoldWrapper, g, x)
22+
ManifoldsBase.embed_project!(M.M, g, x, g)
23+
return g
24+
end
25+
26+
# experimental
27+
function optimizeManifold_FD(
28+
M::AbstractManifold,
29+
cost::Function,
30+
x0::AbstractArray;
31+
algorithm = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
32+
)
33+
# finitediff setup
34+
r_backend = ManifoldDiff.TangentDiffBackend(
35+
ManifoldDiff.FiniteDifferencesBackend()
36+
)
37+
38+
## finitediff gradient (non-manual)
39+
function costgrad_FD!(X,p)
40+
X .= ManifoldDiff.gradient(M, cost, p, r_backend)
41+
X
42+
end
43+
44+
Optim.optimize(cost, costgrad_FD!, x0, algorithm)
45+
end
46+
47+
148
## ================================================================================================
249
## AbstractPowerManifold with N as field to avoid excessive compiling time.
350
## ================================================================================================

test/manifolds/manifolddiff.jl

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
2+
# using Revise
3+
using Test
4+
using LinearAlgebra
5+
using IncrementalInference
6+
using ManifoldsBase
7+
using Manifolds, Manopt
8+
import Optim
9+
using FiniteDifferences, ManifoldDiff
10+
import Rotations as _Rot
11+
12+
##
13+
14+
# finitediff setup
15+
r_backend = ManifoldDiff.TangentDiffBackend(
16+
ManifoldDiff.FiniteDifferencesBackend()
17+
)
18+
19+
##
20+
@testset "ManifoldDiff, Basic test" begin
21+
##
22+
23+
# problem setup
24+
n = 100
25+
σ = π / 8
26+
M = Manifolds.Sphere(2)
27+
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
28+
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
29+
30+
# objective function
31+
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
32+
# f_(p) = f(M,p)
33+
34+
# non-manual: intrinsic finite differences gradient
35+
function grad_f_FD(M,p)
36+
f_(p_) = f(M,p_)
37+
ManifoldDiff.gradient(M, f_, p, r_backend)
38+
end
39+
# manual gradient
40+
# grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p)));
41+
42+
43+
# and solve
44+
@time m1 = gradient_descent(M, f, grad_f_FD, data[1])
45+
46+
@info "Basic Manopt test" string(m1')
47+
@test isapprox(p, m1; atol=0.15)
48+
49+
##
50+
end
51+
52+
##
53+
54+
"""
55+
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
56+
57+
Adapts Manifolds.jl manifolds for use in Optim.jl
58+
"""
59+
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
60+
M::TM
61+
end
62+
63+
function Optim.retract!(M::ManifoldWrapper, x)
64+
ManifoldsBase.embed_project!(M.M, x, x)
65+
return x
66+
end
67+
68+
function Optim.project_tangent!(M::ManifoldWrapper, g, x)
69+
ManifoldsBase.embed_project!(M.M, g, x, g)
70+
return g
71+
end
72+
73+
##
74+
@testset "Optim.jl ManifoldWrapper example from mateuszbaran (copied to catch issues on future changes)" begin
75+
##
76+
# Example modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
77+
78+
# example usage of Manifolds.jl manifolds in Optim.jl
79+
M = Manifolds.Sphere(2)
80+
x0 = [1.0, 0.0, 0.0]
81+
q = [0.0, 1.0, 0.0]
82+
83+
f(p) = 0.5 * distance(M, p, q)^2
84+
85+
# manual gradient
86+
function g!(X, p)
87+
log!(M, X, p, q)
88+
X .*= -1
89+
println(p, X)
90+
end
91+
92+
##
93+
94+
sol = Optim.optimize(f, g!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
95+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
96+
97+
98+
## finitediff gradient (non-manual)
99+
100+
function g_FD!(X,p)
101+
X .= ManifoldDiff.gradient(M, f, p, r_backend)
102+
X
103+
end
104+
105+
#
106+
x0 = [1.0, 0.0, 0.0]
107+
108+
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
109+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
110+
111+
##
112+
113+
# x0 = [1.0, 0.0, 0.0]
114+
# # internal ForwardDfif doesnt work out the box on Manifolds
115+
# sol = Optim.optimize(f, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)); autodiff=:forward )
116+
# @test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
117+
118+
##
119+
end
120+
121+
122+
@testset "Modified Manifolds.jl ManifoldWrapper <: Optim.Manifold for SpecialEuclidean(2)" begin
123+
##
124+
125+
M = Manifolds.SpecialEuclidean(2)
126+
e0 = ArrayPartition([0,0.], [1 0; 0 1.])
127+
128+
x0 = deepcopy(e0)
129+
Cq = 9*ones(3)
130+
while 1.5 < abs(Cq[3])
131+
@show Cq .= randn(3)
132+
# Cq[3] = 1.5 # breaks ConjugateGradient
133+
end
134+
q = exp(M,e0,hat(M,e0,Cq))
135+
136+
f(p) = distance(M, p, q)^2
137+
138+
## finitediff gradient (non-manual)
139+
function g_FD!(X,p)
140+
X .= ManifoldDiff.gradient(M, f, p, r_backend)
141+
X
142+
end
143+
144+
## sanity check gradients
145+
146+
X = hat(M, e0, zeros(3))
147+
g_FD!(X, q)
148+
# gradient at the optimal point should be zero
149+
@show X_ = [X.x[1][:]; X.x[2][:]]
150+
@test isapprox(0, sum(abs.(X_)); atol=1e-8 )
151+
152+
# gradient not the optimal point should be non-zero
153+
g_FD!(X, e0)
154+
@show X_ = [X.x[1][:]; X.x[2][:]]
155+
@test 0.01 < sum(abs.(X_))
156+
157+
## do optimization
158+
x0 = deepcopy(e0)
159+
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
160+
Cq .= randn(3)
161+
# Cq[
162+
@show sol.minimizer
163+
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
164+
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
165+
166+
##
167+
end
168+
169+
170+
@testset "Modified ManifoldsWrapper for Optim.Manifolds, SpecialEuclidean(3)" begin
171+
##
172+
173+
174+
M = Manifolds.SpecialEuclidean(3)
175+
e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))
176+
177+
x0 = deepcopy(e0)
178+
Cq = 0.5*randn(6)
179+
q = exp(M,e0,hat(M,e0,Cq))
180+
181+
f(p) = distance(M, p, q)^2
182+
183+
## finitediff gradient (non-manual)
184+
function g_FD!(X,p)
185+
X .= ManifoldDiff.gradient(M, f, p, r_backend)
186+
X
187+
end
188+
189+
## sanity check gradients
190+
191+
X = hat(M, e0, zeros(6))
192+
g_FD!(X, q)
193+
194+
@show X_ = [X.x[1][:]; X.x[2][:]]
195+
# gradient at the optimal point should be zero
196+
@test isapprox(0, sum(abs.(X_)); atol=1e-8 )
197+
198+
# gradient not the optimal point should be non-zero
199+
g_FD!(X, e0)
200+
@show X_ = [X.x[1][:]; X.x[2][:]]
201+
@test 0.01 < sum(abs.(X_))
202+
203+
## do optimization
204+
x0 = deepcopy(e0)
205+
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
206+
# Cq .= 0.5*randn(6)
207+
# Cq[
208+
@show sol.minimizer
209+
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
210+
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-3)
211+
212+
213+
##
214+
end
215+
216+
217+
@testset "Optim.Manifolds, SpecialEuclidean(3), using IIF.optimizeManifold_FD" begin
218+
##
219+
220+
M = Manifolds.SpecialEuclidean(3)
221+
e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))
222+
223+
x0 = deepcopy(e0)
224+
Cq = 0.5*randn(6)
225+
q = exp(M,e0,hat(M,e0,Cq))
226+
227+
f(p) = distance(M, p, q)^2
228+
229+
sol = IncrementalInference.optimizeManifold_FD(M,f,x0)
230+
231+
@show sol.minimizer
232+
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
233+
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
234+
235+
236+
##
237+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ end
1111

1212
if TEST_GROUP in ["all", "basic_functional_group"]
1313
# more frequent stochasic failures from numerics
14+
include("manifolds/manifolddiff.jl")
1415
include("testSpecialEuclidean2Mani.jl")
1516
include("testEuclidDistance.jl")
1617

0 commit comments

Comments
 (0)