Skip to content

Commit 0757330

Browse files
authored
Improve parametric performance and Cleanup (#1788)
1 parent 65e1715 commit 0757330

File tree

5 files changed

+36
-91
lines changed

5 files changed

+36
-91
lines changed

src/entities/CalcFactor.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,27 +100,24 @@ end
100100

101101
struct CalcFactorResidual{
102102
FT <: AbstractFactor,
103-
C,
103+
N,
104104
D,
105-
L,
106-
P,
107105
MEAS <: AbstractArray,
108-
N
106+
L,
107+
C,
109108
} <: CalcFactor{FT}
110109
faclbl::Symbol
111110
factor::FT
112-
cache::C
113111
varOrder::NTuple{N, Symbol}
114112
varOrderIdxs::NTuple{N, Int}
115-
points::P #TODO remove or not?
116113
meas::MEAS
117-
::SMatrix{D, D, Float64, L} #TODO remove or not?
118114
sqrt_iΣ::SMatrix{D, D, Float64, L}
115+
cache::C
119116
end
120117

121-
_nvars(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = N
118+
_nvars(::CalcFactorResidual{FT, N, D, MEAS, L, C}) where {FT, N, D, MEAS, L, C} = N
122119
# _typeof_meas(::CalcFactorManopt{FT, C, D, L, MEAS, N}) where {FT, C, D, L, MEAS, N} = MEAS
123-
DFG.getDimension(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = D
120+
DFG.getDimension(::CalcFactorResidual{FT, N, D, MEAS, L, C}) where {FT, N, D, MEAS, L, C} = D
124121

125122
# workaround for issue #1781
126123
import Base: getproperty

src/manifolds/services/ManifoldsExtentions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ function Manifolds.get_vector!(M::NPowerManifold, Y, p, c, B::AbstractBasis)
4747
Y[i] = get_vector(
4848
M.manifold,
4949
Manifolds._read(M, rep_size, p, i),
50-
view(c, v_iter:(v_iter + dim - 1)),
50+
# view(c, v_iter:(v_iter + dim - 1)),
51+
SVector{dim}(view(c, v_iter:(v_iter + dim - 1))),
5152
B,
5253
)
5354
v_iter += dim

src/parametric/services/ParametricManopt.jl

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function getVarIntLabelMap(vartypeslist::OrderedDict{DataType, Vector{Symbol}})
1414
return varIntLabel, varlabelsAP
1515
end
1616

17-
function CalcFactorResidual(fg, fct::DFGFactor, varIntLabel, points::Union{Nothing,ArrayPartition}=nothing)
17+
function CalcFactorResidual(fg, fct::DFGFactor, varIntLabel)
1818
fac_func = getFactorType(fct)
1919
varOrder = getVariableOrder(fct)
2020

@@ -29,22 +29,14 @@ function CalcFactorResidual(fg, fct::DFGFactor, varIntLabel, points::Union{Nothi
2929
sqrt_iΣ = convert(SMatrix{dims, dims}, sqrt(iΣ))
3030
cache = preambleCache(fg, getVariable.(fg, varOrder), getFactorType(fct))
3131

32-
if isnothing(points)
33-
points_view = nothing
34-
else
35-
points_view = @view points[varOrderIdxs]
36-
end
37-
3832
return CalcFactorResidual(
3933
fct.label,
4034
getFactorMechanics(fac_func),
41-
cache,
4235
tuple(varOrder...),
4336
tuple(varOrderIdxs...),
44-
points_view,
4537
meas,
46-
iΣ,
4738
sqrt_iΣ,
39+
cache,
4840
)
4941
end
5042

@@ -53,7 +45,7 @@ end
5345
CalcFactorResidualAP
5446
Create an `ArrayPartition` of `CalcFactorResidual`s.
5547
"""
56-
function CalcFactorResidualAP(fg::GraphsDFG, factorLabels::Vector{Symbol}, varIntLabel::OrderedDict{Symbol, Int64}, points)
48+
function CalcFactorResidualAP(fg::GraphsDFG, factorLabels::Vector{Symbol}, varIntLabel::OrderedDict{Symbol, Int64})
5749
factypes, typedict, alltypes = getFactorTypesCount(getFactor.(fg, factorLabels))
5850

5951
# skip non-numeric prior (MetaPrior)
@@ -63,7 +55,7 @@ function CalcFactorResidualAP(fg::GraphsDFG, factorLabels::Vector{Symbol}, varIn
6355

6456
parts = map(values(alltypes)) do labels
6557
map(getFactor.(fg, labels)) do fct
66-
CalcFactorResidual(fg, fct, varIntLabel, points)
58+
CalcFactorResidual(fg, fct, varIntLabel)
6759
end
6860
end
6961
parts_tuple = (parts...,)
@@ -76,10 +68,6 @@ function (cfm::CalcFactorResidual)(p)
7668
return cfm.sqrt_iΣ * cfm(meas, points...)
7769
end
7870

79-
function (cfm::CalcFactorResidual{T})() where T
80-
return cfm.sqrt_iΣ * cfm(cfm.meas, cfm.points...)
81-
end
82-
8371
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
8472
struct CostFres_cond!{PT, CFT}
8573
points::PT
@@ -112,14 +100,13 @@ end
112100

113101
function calcFactorResVec!(
114102
x::Vector{T},
115-
cfm_part::Vector{<:CalcFactorResidual{FT}},
103+
cfm_part::Vector{<:CalcFactorResidual{FT, N, D}},
116104
p::AbstractArray{T},
117105
st::Int
118-
) where {T,FT}
119-
l = getDimension(cfm_part[1]) # all should be the same
106+
) where {T, FT, N, D}
120107
for cfm in cfm_part
121-
x[st:st + l - 1] = cfm(p) #NOTE looks like do not broadcast here
122-
st += l
108+
x[st:st + D - 1] = cfm(p) #NOTE looks like do not broadcast here
109+
st += D
123110
end
124111
return st
125112
end
@@ -329,7 +316,7 @@ function solve_RLM(
329316
end
330317

331318
# create an ArrayPartition{CalcFactorResidual} for faclabels
332-
calcfacs = CalcFactorResidualAP(fg, faclabels, varIntLabel, p0)
319+
calcfacs = CalcFactorResidualAP(fg, faclabels, varIntLabel)
333320

334321
#cost and jacobian functions
335322
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
@@ -446,7 +433,7 @@ function solve_RLM_conditional(
446433
# varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel)
447434
# varIntLabel_separators = filter(p->first(p) in separators, varIntLabel)
448435

449-
calcfacs = CalcFactorResidualAP(fg, faclabels, all_varIntLabel, all_points)
436+
calcfacs = CalcFactorResidualAP(fg, faclabels, all_varIntLabel)
450437

451438
# get the manifold and variable types
452439

src/services/FactorGradients.jl

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,67 +5,27 @@
55
function factorJacobian(
66
fg,
77
faclabel::Symbol,
8-
fro_p = first.(getVal.(fg, getVariableOrder(fg, faclabel), solveKey = :parametric))
8+
p0 = ArrayPartition(first.(getVal.(fg, getVariableOrder(fg, faclabel), solveKey = :parametric))...),
9+
backend = ManifoldDiff.TangentDiffBackend(ManifoldDiff.FiniteDiffBackend()),
910
)
10-
faclabels = Symbol[faclabel;]
11-
frontals = ls(fg, faclabel)
12-
separators = Symbol[] # setdiff(ls(fg), frontals)
1311

14-
# get the subgraph formed by all frontals, separators and fully connected factors
15-
varlabels = union(frontals, separators)
12+
fac = getFactor(fg, faclabel)
13+
varlabels = getVariableOrder(fac)
14+
varIntLabel = OrderedDict(zip(varlabels, eachindex(varlabels)))
1615

17-
filter!(faclabels) do fl
18-
return issubset(getVariableOrder(fg, fl), varlabels)
16+
cfm = CalcFactorResidual(fg, fac, varIntLabel)
17+
18+
function costf(p)
19+
points = map(idx->p.x[idx], cfm.varOrderIdxs)
20+
return cfm.sqrt_iΣ * cfm(cfm.meas, points...)
1921
end
2022

21-
facs = getFactor.(fg, faclabels)
22-
23-
# so the subgraph consists of varlabels(frontals + separators) and faclabels
24-
25-
varIntLabel = OrderedDict(zip(varlabels, collect(1:length(varlabels))))
26-
27-
# varIntLabel_frontals = filter(p->first(p) in frontals, varIntLabel)
28-
# varIntLabel_separators = filter(p->first(p) in separators, varIntLabel)
29-
30-
calcfacs = map(f->IIF.CalcFactorManopt(f, varIntLabel), facs)
31-
32-
# get the manifold and variable types
33-
frontal_vars = getVariable.(fg, frontals)
34-
vartypes, vartypecount, vartypeslist = IIF.getVariableTypesCount(frontal_vars)
35-
36-
PMs = map(vartypes) do vartype
37-
N = vartypecount[vartype]
38-
G = getManifold(vartype)
39-
return IIF.NPowerManifold(G, N)
40-
end
41-
M = ProductManifold(PMs...)
42-
43-
#
44-
#FIXME
45-
@assert length(M.manifolds) == 1 "#FIXME, this only works with 1 manifold type component"
46-
MM = M.manifolds[1]
47-
48-
# inital values and separators from fg
49-
# fro_p = first.(getVal.(frontal_vars, solveKey = :parametric))
50-
# sep_p::Vector{eltype(fro_p)} = first.(getVal.(fg, separators, solveKey = :parametric))
51-
sep_p = Vector{eltype(fro_p)}()
52-
53-
#cost and jacobian functions
54-
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
55-
costF! = IIF.CostF_RLM!(calcfacs, fro_p, sep_p)
56-
# jacobian of function for Riemannian Levenberg-Marquardt
57-
jacF! = IIF.JacF_RLM!(MM, costF!)
58-
59-
num_components = length(jacF!.res)
60-
61-
p0 = deepcopy(fro_p)
62-
63-
# initial_residual_values = zeros(num_components)
64-
J = zeros(num_components, manifold_dimension(MM))
65-
66-
jacF!(MM, J, p0)
67-
68-
J
23+
M_dom = ProductManifold(getManifold.(fg, varlabels)...)
24+
#TODO verify M_codom
25+
M_codom = Euclidean(manifold_dimension(getManifold(fac)))
26+
# Jx(M, p) = ManifoldDiff.jacobian(M, M_codom, calcfac, p, backend)
27+
28+
return ManifoldDiff.jacobian(M_dom, M_codom, costf, p0, backend)
6929
end
7030

7131

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ end
1313
if TEST_GROUP in ["all", "basic_functional_group"]
1414
# more frequent stochasic failures from numerics
1515
include("manifolds/manifolddiff.jl")
16-
# include("manifolds/factordiff.jl") #FIXME restore
16+
include("manifolds/factordiff.jl")
1717
include("testSpecialEuclidean2Mani.jl")
1818
include("testEuclidDistance.jl")
1919

@@ -99,7 +99,7 @@ include("testFluxModelsDistribution.jl")
9999
include("testAnalysisTools.jl")
100100

101101
include("testBasicParametric.jl")
102-
# include("testMixtureParametric.jl") #FIXME parametric mixtures #[TODO open issue]
102+
# include("testMixtureParametric.jl") #FIXME parametric mixtures #1787
103103

104104
# dont run test on ARM, as per issue #527
105105
if Base.Sys.ARCH in [:x86_64;]

0 commit comments

Comments
 (0)