Skip to content

Commit 46f799f

Browse files
committed
basic manifold diff sandbox test
1 parent 78d5f54 commit 46f799f

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

test/manifolds/manifolddiff.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
# using Revise
3+
using Test
4+
using LinearAlgebra
5+
using Manifolds, Manopt
6+
using FiniteDifferences, ManifoldDiff
7+
8+
##
9+
10+
@testset "ManifoldDiff, Basic test" begin
11+
##
12+
13+
# finitediff setup
14+
r_backend = ManifoldDiff.TangentDiffBackend(
15+
ManifoldDiff.FiniteDifferencesBackend()
16+
)
17+
18+
# problem setup
19+
n = 100
20+
σ = π / 8
21+
M = Sphere(2)
22+
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
23+
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
24+
25+
# objective function
26+
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
27+
f_(p) = f(M,p)
28+
29+
# manual gradient
30+
# grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p)));
31+
32+
# non-manual: intrinsic finite differences gradient
33+
grad_f_FD(M,p) = ManifoldDiff.gradient(M, f_, p, r_backend)
34+
35+
# and solve
36+
m1 = gradient_descent(M, f, grad_f_FD, data[1])
37+
38+
@test isapprox(p, m1; atol=0.15)
39+
40+
##
41+
end
42+
43+
##

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("testEuclidDistance.jl")
1818
include("testSphereMani.jl")
1919
include("testSpecialOrthogonalMani.jl")
2020
include("testBasicManifolds.jl")
21+
include("manifolds/manifolddiff.jl")
2122

2223
# start as basic as possible and build from there
2324
include("typeReturnMemRef.jl")

0 commit comments

Comments
 (0)