Skip to content

Commit bdc0ea3

Browse files
committed
wip fixing tests for staticarrays
1 parent 68c9066 commit bdc0ea3

File tree

6 files changed

+30
-14
lines changed

6 files changed

+30
-14
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ ApproxManifoldProducts = "0.7"
5656
BSON = "0.2, 0.3"
5757
Combinatorics = "1.0"
5858
DataStructures = "0.16, 0.17, 0.18"
59-
DistributedFactorGraphs = "0.21"
59+
DistributedFactorGraphs = "0.21, 0.22"
6060
Distributions = "0.24, 0.25"
6161
DocStringExtensions = "0.8, 0.9"
6262
FileIO = "1"
@@ -91,10 +91,11 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
9191
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
9292
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9393
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
94+
LineSearches = ""d3d80556-e9d4-5f37-9878-2ab0fcc64255"
9495
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
9596
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
9697
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
9798
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9899

99100
[targets]
100-
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
101+
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "LineSearches", "Pkg", "Rotations", "Test"]

src/ManifoldSampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ end
3030
function sampleTangent(
3131
M::AbstractDecoratorManifold,
3232
z::Distribution,
33-
p = getPointIdentity(M),
33+
p = identity_element(M), #getPointIdentity(M),
3434
)
3535
return hat(M, p, rand(z, 1)[:]) #TODO find something better than (z,1)[:]
3636
end

src/VariableStatistics.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ function Statistics.cov(
1818
return cov(getManifold(vartype), ptsArr; basis, kwargs...)
1919
end
2020

21-
function calcStdBasicSpread(vartype::InferenceVariable, ptsArr::Vector{P}) where {P}
22-
σ = std(vartype, ptsArr)
21+
function calcStdBasicSpread(vartype::InferenceVariable, ptsArr::AbstractVector) # {P}) where {P}
22+
_makemutable(s) = s
23+
_makemutable(s::StaticArray{Tuple{S},T,N}) where {S,T,N} = MArray{Tuple{S},T,N,S}(s)
24+
_makemutable(s::SMatrix{N,N,T,D}) where {N,T,D} = MMatrix{N,N,T,D}(s)
25+
26+
# silly conversion since Manifolds.std internally replicates eltype ptsArr which doesn't work on StaticArrays
27+
σ = std(vartype, _makemutable.(ptsArr))
2328

2429
#if no std yet, set to 1
2530
msst = 1e-10 < σ ? σ : 1.0

test/testBasicManifolds.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ M = SpecialEuclidean(3)
1515
Mr = M.manifold[2]
1616
pPq = ArrayPartition(zeros(3), exp(Mr, Identity(Mr), hat(Mr, Identity(Mr), w)))
1717
rPc_ = exp(M, Identity(M), hat(M, Identity(M), [zeros(3);w]))
18-
rPc = ArrayPartition(rPc_.parts[1], rPc_.parts[2])
18+
rPc = ArrayPartition(rPc_.x[1], rPc_.x[2])
1919

2020
@test isapprox(pPq.x[1], rPc.x[1])
2121
@test isapprox(pPq.x[2], rPc.x[2])

test/testSpecialOrthogonalMani.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DistributedFactorGraphs
22
using IncrementalInference
3+
using LineSearches
34
using Manifolds
45
using StaticArrays
56
using Test
@@ -14,13 +15,13 @@ Base.convert(::Type{<:Tuple}, M::SpecialOrthogonal{2}) = (:Circular,)
1415
Base.convert(::Type{<:Tuple}, ::IIF.InstanceType{SpecialOrthogonal{2}}) = (:Circular,)
1516

1617
# @defVariable SpecialOrthogonal2 SpecialOrthogonal(2) @MMatrix([1.0 0.0; 0.0 1.0])
17-
@defVariable SpecialOrthogonal2 SpecialOrthogonal(2) [1.0 0.0; 0.0 1.0]
18+
@defVariable SpecialOrthogonal2 SpecialOrthogonal(2) SMatrix{2,2}(1.0, 0.0, 0.0, 1.0)
1819

1920
M = getManifold(SpecialOrthogonal2)
2021
@test M == SpecialOrthogonal(2)
2122
pT = getPointType(SpecialOrthogonal2)
2223
# @test pT == MMatrix{2, 2, Float64, 4}
23-
@test pT == Matrix{Float64}
24+
@test pT == SMatrix{2,2,Float64,4}
2425
= getPointIdentity(SpecialOrthogonal2)
2526
@test== [1.0 0.0; 0.0 1.0]
2627

@@ -43,6 +44,7 @@ vnd = getVariableSolverData(fg, :x0)
4344

4445

4546
##
47+
4648
v1 = addVariable!(fg, :x1, SpecialOrthogonal2)
4749
mf = ManifoldFactor(SpecialOrthogonal(2), MvNormal([pi], [0.01]))
4850
f = addFactor!(fg, [:x0, :x1], mf)
@@ -69,13 +71,13 @@ Base.convert(::Type{<:Tuple}, M::SpecialOrthogonal{3}) = (:Euclid, :Euclid, :Euc
6971
Base.convert(::Type{<:Tuple}, ::IIF.InstanceType{SpecialOrthogonal{3}}) = (:Euclid, :Euclid, :Euclid)
7072

7173
# @defVariable SO3 SpecialOrthogonal(3) @MMatrix([1.0 0.0; 0.0 1.0])
72-
@defVariable SO3 SpecialOrthogonal(3) [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]
74+
@defVariable SO3 SpecialOrthogonal(3) SMatrix{3,3}(1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0)
7375

7476
M = getManifold(SO3)
7577
@test M == SpecialOrthogonal(3)
7678
pT = getPointType(SO3)
7779
# @test pT == MMatrix{2, 2, Float64, 4}
78-
@test pT == Matrix{Float64}
80+
@test pT == SMatrix{3,3,Float64,9}
7981
= getPointIdentity(SO3)
8082
@test== [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]
8183

@@ -98,6 +100,7 @@ points = sampleFactor(fg, :x0f1, 100)
98100
std(SpecialOrthogonal(3), points)
99101

100102
##
103+
101104
v1 = addVariable!(fg, :x1, SO3)
102105
mf = ManifoldFactor(SpecialOrthogonal(3), MvNormal([0.01,0.01,0.01], [0.01,0.01,0.01]))
103106
f = addFactor!(fg, [:x0, :x1], mf)
@@ -121,6 +124,10 @@ vnd = getVariableSolverData(fg, :x1)
121124
@test all(is_point.(Ref(M), vnd.val))
122125

123126

124-
IIF.solveGraphParametric!(fg)
127+
IIF.solveGraphParametric!(
128+
fg;
129+
algorithmkwargs=(;alphaguess = LineSearches.InitialStatic(), linesearch = LineSearches.MoreThuente())
130+
)
131+
125132
##
126133
end

test/testSphereMani.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,25 @@ using Manifolds
44
using StaticArrays
55
using Test
66

7+
import Manifolds: identity_element
8+
79
##
810

911
@testset "Test Sphere(2) prior and relative" begin
1012
##
1113

1214
#FIXME REMOVE! this is type piracy and not a good idea, for testing only!!!
13-
Manifolds.identity_element(::Sphere{2, ℝ}, p::Vector{Float64}) = Float64[1,0,0]
15+
Manifolds.identity_element(::Sphere{2, ℝ}) = SVector(1.0, 0.0, 0.0)
16+
Manifolds.identity_element(::Sphere{2, ℝ}, p::AbstractVector) = SVector(1.0, 0.0, 0.0) # Float64[1,0,0]
1417

1518
Base.convert(::Type{<:Tuple}, M::Sphere{2, ℝ}) = (:Euclid, :Euclid)
1619
Base.convert(::Type{<:Tuple}, ::IIF.InstanceType{Sphere{2, ℝ}}) = (:Euclid, :Euclid)
1720

18-
@defVariable Sphere2 Sphere(2) [1.0, 0.0, 0.0]
21+
@defVariable Sphere2 Sphere(2) SVector(1.0, 0.0, 0.0)
1922
M = getManifold(Sphere2)
2023
@test M == Sphere(2)
2124
pT = getPointType(Sphere2)
22-
@test pT == Vector{Float64}
25+
@test pT == SVector{3,Float64}
2326
= getPointIdentity(Sphere2)
2427
@test== [1.0, 0.0, 0.0]
2528

0 commit comments

Comments
 (0)