Skip to content

Commit 5f2c0d8

Browse files
authored
Merge pull request #1725 from JuliaRobotics/23Q2/twig/manopt
wip w partials, sampleTangent
2 parents 0b1ea37 + 98ba433 commit 5f2c0d8

12 files changed

+70
-36
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ ApproxManifoldProducts = "0.7"
5656
BSON = "0.2, 0.3"
5757
Combinatorics = "1.0"
5858
DataStructures = "0.16, 0.17, 0.18"
59-
DistributedFactorGraphs = "0.21"
59+
DistributedFactorGraphs = "0.21, 0.22"
6060
Distributions = "0.24, 0.25"
6161
DocStringExtensions = "0.8, 0.9"
6262
FileIO = "1"
@@ -91,10 +91,11 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
9191
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
9292
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9393
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
94+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
9495
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
9596
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
9697
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
9798
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9899

99100
[targets]
100-
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
101+
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "LineSearches", "Pkg", "Rotations", "Test"]

src/Factors/GenericFunctions.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ DFG.getManifold(f::ManifoldFactor) = f.M
8080
function getSample(cf::CalcFactor{<:ManifoldFactor{M, Z}}) where {M, Z}
8181
#TODO @assert dim == cf.factor.Z's dimension
8282
#TODO investigate use of SVector if small dims
83-
if M isa ManifoldKernelDensity
84-
ret = sample(cf.factor.Z.belief)[1]
85-
else
86-
ret = rand(cf.factor.Z)
87-
end
88-
# ret = sampleTangent(M, cf.factor.Z)
83+
# if M isa ManifoldKernelDensity
84+
# ret = sample(cf.factor.Z.belief)[1]
85+
# else
86+
# ret = rand(cf.factor.Z)
87+
# end
88+
89+
# ASSUME this function is only used for RelativeFactors which must use measurements as tangents
90+
ret = sampleTangent(cf.factor.M, cf.factor.Z)
8991
#return coordinates as we do not know the point here #TODO separate Lie group
9092
return ret
9193
end

src/ManifoldSampling.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@ function sampleTangent(x::ManifoldKernelDensity, p = mean(x))
2323
end
2424

2525
# Sampling Distributions
26-
function sampleTangent(M::AbstractManifold, z::Distribution, p, basis::AbstractBasis)
26+
# assumes M is a group and will break for Riemannian, but leaving that enhancement as TODO
27+
function sampleTangent(M::AbstractManifold, z::Distribution, p = getPointIdentity(M), basis::AbstractBasis = DefaultOrthogonalBasis())
2728
return get_vector(M, p, rand(z), basis)
2829
end
2930

3031
function sampleTangent(
3132
M::AbstractDecoratorManifold,
3233
z::Distribution,
33-
p = getPointIdentity(M),
34+
p = identity_element(M), #getPointIdentity(M),
3435
)
3536
return hat(M, p, rand(z, 1)[:]) #TODO find something better than (z,1)[:]
3637
end

src/VariableStatistics.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ function Statistics.cov(
1818
return cov(getManifold(vartype), ptsArr; basis, kwargs...)
1919
end
2020

21-
function calcStdBasicSpread(vartype::InferenceVariable, ptsArr::Vector{P}) where {P}
22-
σ = std(vartype, ptsArr)
21+
function calcStdBasicSpread(vartype::InferenceVariable, ptsArr::AbstractVector) # {P}) where {P}
22+
_makemutable(s) = s
23+
_makemutable(s::StaticArray{Tuple{S},T,N}) where {S,T,N} = MArray{Tuple{S},T,N,S}(s)
24+
_makemutable(s::SMatrix{N,N,T,D}) where {N,T,D} = MMatrix{N,N,T,D}(s)
25+
26+
# FIXME, silly conversion since Manifolds.std internally replicates eltype ptsArr which doesn't work on StaticArrays
27+
σ = std(vartype, _makemutable.(ptsArr))
2328

2429
#if no std yet, set to 1
2530
msst = 1e-10 < σ ? σ : 1.0

src/Variables/DefaultVariables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function DFG.getDimension(val::InstanceType{Position{N}}) where {N}
1616
return manifold_dimension(getManifold(val))
1717
end
1818
DFG.getPointType(::Type{Position{N}}) where {N} = Vector{Float64}
19-
DFG.getPointIdentity(M_::Type{Position{N}}) where {N} = zeros(N) # identity_element(getManifold(M_), zeros(N))
19+
DFG.getPointIdentity(M_::Type{Position{N}}) where {N} = @SVector(zeros(N)) # identity_element(getManifold(M_), zeros(N))
2020

2121
function Base.convert(
2222
::Type{<:ManifoldsBase.AbstractManifold},

src/services/FGOSUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ function manikde!(
122122
) where {P <: Union{<:AbstractArray, <:Number, <:Manifolds.ArrayPartition}}
123123
#
124124
M = getManifold(variableType)
125+
# @info "pts" P typeof(pts[1]) pts[1]
125126
infoPerCoord = ones(AMP.getNumberCoords(M, pts[1]))
126127
return AMP.manikde!(M, pts; infoPerCoord, kw...)
127128
end

test/manifolds/manifolddiff.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ end
9292
##
9393

9494
sol = Optim.optimize(f, g!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
95-
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
95+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-6)
9696

9797

9898
## finitediff gradient (non-manual)
@@ -106,7 +106,7 @@ end
106106
x0 = [1.0, 0.0, 0.0]
107107

108108
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
109-
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
109+
@test isapprox([0,1,0.], sol.minimizer; atol=1e-6)
110110

111111
##
112112

@@ -161,7 +161,7 @@ Cq .= randn(3)
161161
# Cq[
162162
@show sol.minimizer
163163
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
164-
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
164+
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-3)
165165

166166
##
167167
end
@@ -193,7 +193,7 @@ g_FD!(X, q)
193193

194194
@show X_ = [X.x[1][:]; X.x[2][:]]
195195
# gradient at the optimal point should be zero
196-
@test isapprox(0, sum(abs.(X_)); atol=1e-8 )
196+
@test isapprox(0, sum(abs.(X_)); atol=1e-6 )
197197

198198
# gradient not the optimal point should be non-zero
199199
g_FD!(X, e0)
@@ -230,7 +230,7 @@ sol = IncrementalInference.optimizeManifold_FD(M,f,x0)
230230

231231
@show sol.minimizer
232232
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
233-
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-4)
233+
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-3)
234234

235235

236236
##

test/testBasicManifolds.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ M = SpecialEuclidean(3)
1515
Mr = M.manifold[2]
1616
pPq = ArrayPartition(zeros(3), exp(Mr, Identity(Mr), hat(Mr, Identity(Mr), w)))
1717
rPc_ = exp(M, Identity(M), hat(M, Identity(M), [zeros(3);w]))
18-
rPc = ArrayPartition(rPc_.parts[1], rPc_.parts[2])
18+
rPc = ArrayPartition(rPc_.x[1], rPc_.x[2])
1919

2020
@test isapprox(pPq.x[1], rPc.x[1])
2121
@test isapprox(pPq.x[2], rPc.x[2])

test/testBasicParametric.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ for i in 0:10
139139
sym = Symbol("x",i)
140140
var = getVariable(fg,sym)
141141
@show val = var.solverDataDict[:parametric].val
142-
@test isapprox(val[1][1], i, atol=1e-4)
143-
@test isapprox(val[1][2], i, atol=1e-4)
142+
@test isapprox(val[1][1], i, atol=1e-3)
143+
@test isapprox(val[1][2], i, atol=1e-3)
144144
end
145145

146146
##
@@ -179,9 +179,9 @@ foreach(fct->println(fct.label, ": ", getFactorType(fct).Z), getFactors(fg))
179179
d,st,vs,Σ = IIF.solveGraphParametric(fg)
180180

181181
foreach(println, d)
182-
@test isapprox(d[:x0].val[1][1], -0.01, atol=1e-4)
183-
@test isapprox(d[:x1].val[1][1], 0.0, atol=1e-4)
184-
@test isapprox(d[:x2].val[1][1], 0.01, atol=1e-4)
182+
@test isapprox(d[:x0].val[1][1], -0.01, atol=1e-3)
183+
@test isapprox(d[:x1].val[1][1], 0.0, atol=1e-3)
184+
@test isapprox(d[:x2].val[1][1], 0.01, atol=1e-3)
185185

186186

187187
##
@@ -202,9 +202,9 @@ tree2 = solveTree!(fg; algorithm=:parametric, eliminationOrder=[:x0, :x2, :x1])
202202
# end
203203
foreach(v->println(v.label, ": ", DFG.getSolverData(v, :parametric).val), getVariables(fg))
204204

205-
@test isapprox(getVariable(fg,:x0).solverDataDict[:parametric].val[1][1], -0.01, atol=1e-4)
206-
@test isapprox(getVariable(fg,:x1).solverDataDict[:parametric].val[1][1], 0.0, atol=1e-4)
207-
@test isapprox(getVariable(fg,:x2).solverDataDict[:parametric].val[1][1], 0.01, atol=1e-4)
205+
@test isapprox(getVariable(fg,:x0).solverDataDict[:parametric].val[1][1], -0.01, atol=1e-3)
206+
@test isapprox(getVariable(fg,:x1).solverDataDict[:parametric].val[1][1], 0.0, atol=1e-3)
207+
@test isapprox(getVariable(fg,:x2).solverDataDict[:parametric].val[1][1], 0.01, atol=1e-3)
208208

209209
## ##############################################################################
210210
## multiple sections

test/testSpecialEuclidean2Mani.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,23 @@ doautoinit!(fg, :x0)
388388
@test length(getPoints(getBelief(fg, :x0))) == getSolverParams(fg).N # 120
389389
# @info "PassThrough transfers the full point count to the graph, unless a product is calculated during the propagateBelief step."
390390

391+
392+
393+
## check the partials magic
394+
395+
dens, ipc = propagateBelief(fg,:x0,:)
396+
testv = deepcopy(getVariable(fg, :x0))
397+
setBelief!(testv, dens, true, ipc)
398+
399+
391400
##
392401

393-
solveGraph!(fg);
402+
smtasks = Task[]
403+
solveGraph!(fg; smtasks);
404+
# hists = fetchCliqHistoryAll!(smtasks)
405+
# printCSMHistoryLogical(hists)
406+
# hists_ = deepcopy(hists)
407+
# repeatCSMStep!(hists, 1, 6)
394408

395409
@test 120 == length(getPoints(fg, :x0))
396410

0 commit comments

Comments
 (0)