Skip to content

Commit 1c24a1f

Browse files
committed
AbstractCalcFactor and some manopt
1 parent af6e166 commit 1c24a1f

File tree

8 files changed

+63
-46
lines changed

8 files changed

+63
-46
lines changed

src/ExportAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ export CSMHistory,
236236

237237
# Factor operational memory
238238
CommonConvWrapper,
239+
AbstractCalcFactor,
239240
CalcFactor,
240241
getCliqVarInitOrderUp,
241242
getCliqNumAssocFactorsPerVar,

src/Factors/Circular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ CircularCircular(::UniformScaling) = CircularCircular(Normal())
2121

2222
DFG.getManifold(::CircularCircular) = RealCircleGroup()
2323

24-
function (cf::CalcFactor{<:CircularCircular})(X, p, q)
24+
function (cf::AbstractCalcFactor{<:CircularCircular})(X, p, q)
2525
#
2626
M = cf.manifold # getManifold(cf.factor)
2727
return distanceTangent2Point(M, X, p, q)

src/Factors/DefaultPrior.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ getManifold(pr::Prior) = TranslationGroup(getDimension(pr.Z))
1414
# getSample(cf::CalcFactor{<:Prior}) = rand(cf.factor.Z)
1515

1616
# basic default
17-
(s::CalcFactor{<:Prior})(z, x1) = z .- x1
17+
(s::AbstractCalcFactor{<:Prior})(z, x1) = z .- x1
1818

1919
## packed types are still developed by hand. Future versions would likely use a @packable macro to write Protobuf safe versions of factors
2020

src/Factors/EuclidDistance.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ getManifold(::InstanceType{EuclidDistance}) = TranslationGroup(1)
1717
getDimension(::InstanceType{<:EuclidDistance}) = 1
1818

1919
# new and simplified interface for both nonparametric and parametric
20-
(s::CalcFactor{<:EuclidDistance})(z, x1, x2) = z .- norm(x2 .- x1)
20+
(s::AbstractCalcFactor{<:EuclidDistance})(z, x1, x2) = z .- norm(x2 .- x1)
2121

2222
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{EuclidDistance})
2323
return Manifolds.TranslationGroup(1)

src/Factors/LinearRelative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ getManifold(::InstanceType{LinearRelative{N}}) where {N} = getManifold(Continuou
3939
getDimension(::InstanceType{LinearRelative{N}}) where {N} = N
4040

4141
# new and simplified interface for both nonparametric and parametric
42-
function (s::CalcFactor{<:LinearRelative})(z, x1, x2)
42+
function (s::AbstractCalcFactor{<:LinearRelative})(z, x1, x2)
4343
# TODO convert to distance(distance(x2,x1),z) # or use dispatch on `-` -- what to do about `.-`
4444
# if s._sampleIdx < 5
4545
# @info "LinearRelative" s._sampleIdx "$z" "$x1" "$x2" s.solvefor getLabel.(s.fullvariables)

src/Factors/MsgPrior.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ end
3232
getManifold(mp::MsgPrior{<:ManifoldKernelDensity}) = mp.Z.manifold
3333
getManifold(mp::MsgPrior) = mp.M
3434

35-
(cfo::CalcFactor{<:MsgPrior})(z, x1) = z .- x1
35+
#FIXME this will not work on manifolds
36+
(cfo::AbstractCalcFactor{<:MsgPrior})(z, x1) = z .- x1
3637

3738
Base.@kwdef struct PackedMsgPrior <: AbstractPackedFactor
3839
Z::PackedSamplableBelief

src/entities/CalcFactor.jl

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
abstract type AbstractMaxMixtureSolver end
33

44

5+
abstract type AbstractCalcFactor{T<:AbstractFactor} end
6+
57
"""
68
$TYPEDEF
79
@@ -32,7 +34,7 @@ struct CalcFactor{
3234
C,
3335
VT <: Tuple,
3436
M <: AbstractManifold
35-
}
37+
} <: AbstractCalcFactor{FT}
3638
""" the interface compliant user object functor containing the data and logic """
3739
factor::FT
3840
""" what is the sample (particle) id for which the residual is being calculated """
@@ -65,32 +67,57 @@ Related
6567
6668
[`CalcFactor`](@ref)
6769
"""
68-
struct CalcFactorMahalanobis{N, D, L, S <: Union{Nothing, AbstractMaxMixtureSolver}}
70+
struct CalcFactorMahalanobis{
71+
FT,
72+
N,
73+
C,
74+
MEAS<:AbstractArray,
75+
D,
76+
L,
77+
S <: Union{Nothing, AbstractMaxMixtureSolver}
78+
} <: AbstractCalcFactor{FT}
6979
faclbl::Symbol
70-
calcfactor!::CalcFactor
80+
factor::FT
81+
cache::C
7182
varOrder::Vector{Symbol}
72-
meas::NTuple{N, <:AbstractArray}
83+
meas::NTuple{N, MEAS}
7384
::NTuple{N, SMatrix{D, D, Float64, L}}
7485
specialAlg::S
7586
end
7687

7788

89+
# struct CalcFactorMahalanobis{N, D, L, S <: Union{Nothing, AbstractMaxMixtureSolver}} <: AbstractCalcFactor{FT}
90+
# faclbl::Symbol
91+
# calcfactor!::CalcFactor
92+
# varOrder::Vector{Symbol}
93+
# meas::NTuple{N, <:AbstractArray}
94+
# iΣ::NTuple{N, SMatrix{D, D, Float64, L}}
95+
# specialAlg::S
96+
# end
7897

79-
98+
#rename to CalcFactorResidual
8099
struct CalcFactorManopt{
100+
FT <: AbstractFactor,
101+
C,
81102
D,
82103
L,
83-
FT <: AbstractFactor,
84-
M <: AbstractManifold,
104+
P,
85105
MEAS <: AbstractArray,
86-
}
106+
N
107+
} <: AbstractCalcFactor{FT}
87108
faclbl::Symbol
88-
calcfactor!::CalcFactor{FT, Nothing, Nothing, Tuple{}, M}
89-
varOrder::Vector{Symbol}
90-
varOrderIdxs::Vector{Int}
109+
factor::FT
110+
cache::C
111+
varOrder::NTuple{N, Symbol}
112+
varOrderIdxs::NTuple{N, Int}
113+
points::P
91114
meas::MEAS
92115
::SMatrix{D, D, Float64, L}
93116
sqrt_iΣ::SMatrix{D, D, Float64, L}
94117
end
95118

119+
_nvars(::CalcFactorManopt{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = N
120+
# _typeof_meas(::CalcFactorManopt{FT, C, D, L, MEAS, N}) where {FT, C, D, L, MEAS, N} = MEAS
121+
DFG.getDimension(::CalcFactorManopt{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = D
122+
96123

src/parametric/services/ParametricManoptDev.jl

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using SparseArrays
1111

1212

1313

14-
function CalcFactorManopt(fct::DFGFactor, varIntLabel)
14+
function CalcFactorManopt(fg, fct::DFGFactor, varIntLabel)
1515
fac_func = getFactorType(fct)
1616
varOrder = getVariableOrder(fct)
1717

@@ -20,42 +20,28 @@ function CalcFactorManopt(fct::DFGFactor, varIntLabel)
2020
M = getManifold(getFactorType(fct))
2121

2222
dims = manifold_dimension(M)
23-
ϵ = getPointIdentity(M)
24-
25-
_meas, _iΣ = getMeasurementParametric(fct)
26-
if fac_func isa ManifoldPrior
27-
meas = _meas # already a point on M
28-
elseif fac_func isa AbstractPrior
29-
X = get_vector(M, ϵ, _meas, DefaultOrthonormalBasis())
30-
meas = exp(M, ϵ, X) # convert to point on M
31-
else
32-
# its a relative factor so should be a tangent vector
33-
meas = convert(typeof(ϵ), get_vector(M, ϵ, _meas, DefaultOrthonormalBasis()))
34-
end
3523

36-
# make sure its an SMatrix
37-
= convert(SMatrix{dims, dims}, _iΣ)
24+
meas, iΣ = getFactorMeasurementParametric(fct)
3825

3926
sqrt_iΣ = convert(SMatrix{dims, dims}, sqrt(iΣ))
40-
# cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
27+
cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
4128

42-
calcf = CalcFactor(
29+
return CalcFactorManopt(
30+
fct.label,
4331
getFactorMechanics(fac_func),
44-
0,
45-
nothing,
46-
true,
47-
nothing,#cache,
48-
(), #DFGVariable[],
49-
0,
50-
getManifold(fac_func),
32+
cache,
33+
varOrder,
34+
varOrderIdxs,
35+
meas,
36+
iΣ,
37+
sqrt_iΣ,
5138
)
52-
return CalcFactorManopt(fct.label, calcf, varOrder, varOrderIdxs, meas, iΣ, sqrt_iΣ)
5339
end
5440

5541
function (cfm::CalcFactorManopt)(p)
5642
meas = cfm.meas
5743
idx = cfm.varOrderIdxs
58-
return cfm.sqrt_iΣ * cfm.calcfactor!(meas, p[idx]...)
44+
return cfm.sqrt_iΣ * cfm(meas, p[idx]...)
5945
end
6046

6147
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
@@ -83,7 +69,8 @@ struct JacF_RLM!{CF, T}
8369
res::Vector{Float64}
8470
end
8571

86-
function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthonormalBasis())
72+
# function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthonormalBasis())
73+
function JacF_RLM!(M, costF!; basis_domain::AbstractBasis = DefaultOrthogonalBasis())
8774

8875
p = costF!.points
8976

@@ -108,7 +95,8 @@ function (jacF!::JacF_RLM!)(
10895
M::AbstractManifold,
10996
J,
11097
p;
111-
basis_domain::AbstractBasis = DefaultOrthonormalBasis(),
98+
# basis_domain::AbstractBasis = DefaultOrthonormalBasis(),
99+
basis_domain::AbstractBasis = DefaultOrthogonalBasis(),
112100
)
113101

114102
X0 = jacF!.X0
@@ -196,7 +184,7 @@ function solve_RLM(
196184
# varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel)
197185
# varIntLabel_separators = filter(p->first(p) in separators, varIntLabel)
198186

199-
calcfacs = map(f->CalcFactorManopt(f, varIntLabel), facs)
187+
calcfacs = map(f->CalcFactorManopt(fg, f, varIntLabel), facs)
200188

201189
# get the manifold and variable types
202190
frontal_vars = getVariable.(fg, frontals)
@@ -262,7 +250,7 @@ function solve_RLM_sparse(fg; kwargs...)
262250

263251
varIntLabel = OrderedDict(zip(varlabels, collect(1:length(varlabels))))
264252

265-
calcfacs = CalcFactorManopt.(facs, Ref(varIntLabel))
253+
calcfacs = CalcFactorManopt.(fg, facs, Ref(varIntLabel))
266254

267255
# get the manifold and variable types
268256
vars = getVariable.(fg, varlabels)

0 commit comments

Comments
 (0)