Skip to content

Commit 1cbb4ff

Browse files
committed
new optimizeManifold_FD convenience function
1 parent c53422d commit 1cbb4ff

File tree

4 files changed

+51
-3
lines changed

4 files changed

+51
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
20+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2021
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
2122
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
2223
KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
2324
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2425
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2526
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
2627
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
28+
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
2729
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
2830
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
2931
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -78,12 +80,10 @@ julia = "1.8"
7880

7981
[extras]
8082
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
81-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
8283
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
8384
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
8485
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
8586
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
86-
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
8787
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
8888
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8989
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"

src/IncrementalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Reexport
1515
using Manifolds
1616
using RecursiveArrayTools: ArrayPartition
1717
export ArrayPartition
18+
using ManifoldDiff
1819

1920
using OrderedCollections: OrderedDict
2021

src/ManifoldsExtentions.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,50 @@
1+
2+
## ================================================================================================
3+
## Manifold and ManifoldDiff use with Optim
4+
## ================================================================================================
5+
6+
# Modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
7+
"""
8+
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
9+
10+
Adapts Manifolds.jl manifolds for use in Optim.jl
11+
"""
12+
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
13+
M::TM
14+
end
15+
16+
function Optim.retract!(M::ManifoldWrapper, x)
17+
ManifoldsBase.embed_project!(M.M, x, x)
18+
return x
19+
end
20+
21+
function Optim.project_tangent!(M::ManifoldWrapper, g, x)
22+
ManifoldsBase.embed_project!(M.M, g, x, g)
23+
return g
24+
end
25+
26+
# experimental
27+
function optimizeManifold_FD(
28+
M::AbstractManifold,
29+
cost::Function,
30+
x0::AbstractArray;
31+
algorithm = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
32+
)
33+
# finitediff setup
34+
r_backend = ManifoldDiff.TangentDiffBackend(
35+
ManifoldDiff.FiniteDifferencesBackend()
36+
)
37+
38+
## finitediff gradient (non-manual)
39+
function costgrad_FD!(X,p)
40+
X .= ManifoldDiff.gradient(M, cost, p, r_backend)
41+
X
42+
end
43+
44+
Optim.optimize(cost, costgrad_FD!, x0, algorithm)
45+
end
46+
47+
148
## ================================================================================================
249
## AbstractPowerManifold with N as field to avoid excessive compiling time.
350
## ================================================================================================

test/manifolds/manifolddiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ r_backend = ManifoldDiff.TangentDiffBackend(
2222
# problem setup
2323
n = 100
2424
σ = π / 8
25-
M = Sphere(2)
25+
M = Manifolds.Sphere(2)
2626
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
2727
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
2828

0 commit comments

Comments
 (0)