Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion IncrementalInference/src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ import ApproxManifoldProducts: isPartial
import ApproxManifoldProducts: _update!
import DistributedFactorGraphs: addVariable!, addFactor!, ls, lsf, isInitialized
import DistributedFactorGraphs: compare
import DistributedFactorGraphs: rebuildFactorCache!
import DistributedFactorGraphs: getDimension, getManifold, getPointType, getPointIdentity
import DistributedFactorGraphs: getPoint, getCoordinates
import DistributedFactorGraphs: getStateKind
Expand Down
23 changes: 19 additions & 4 deletions IncrementalInference/src/entities/BeliefTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,24 @@ function TreeBelief(
end

function TreeBelief(vnd::State, solvDim::Real = 0)
TreeBelief(DFG.getDensityKind(vnd), vnd, solvDim)
end

function TreeBelief(::DFG.GaussianDensityKind, vnd::State, solvDim::Real = 0)
return TreeBelief(
DFG.refMeans(vnd),
DFG.refCovariances(vnd)[1],
vnd.observability,
getStateKind(vnd),
getManifold(vnd),
solvDim,
)
end

function TreeBelief(::DFG.NonparametricDensityKind, vnd::State, solvDim::Real = 0)
return TreeBelief(
vnd.val,
vnd.bw,
DFG.refPoints(vnd),
DFG.refBandwidth(vnd),
vnd.observability,
getStateKind(vnd),
getManifold(vnd),
Expand All @@ -93,9 +108,9 @@ function TreeBelief(vari::VariableCompute, solveKey::Symbol = :default; solvable
end
#

getStateKind(tb::TreeBelief) = tb.variableType
DFG.getStateKind(tb::TreeBelief) = tb.variableType

getManifold(treeb::TreeBelief) = getManifold(treeb.variableType)
DFG.getManifold(treeb::TreeBelief) = getManifold(treeb.variableType)

function compare(t1::TreeBelief, t2::TreeBelief)
TP = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
# fill in the variable node data value
logCSM(csmc, "$(csmc.cliq.id) up: updating $v : $val")
vnd.val[1] = val.val
DFG.refMeans(vnd)[1] = val.val
#calculate and fill in covariance
#TODO rather broadcast than make new memory
vnd.bw = val.cov
DFG.refCovariances(vnd)[1] = val.cov
end
# elseif length(lsfPriors(csmc.cliqSubFg)) == 0 #FIXME
# @error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve, but ignoring since no priors" result
Expand Down Expand Up @@ -118,8 +118,8 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
#TODO maybe combine variable and factor in new prior?
vnd = getState(getVariable(csmc.cliqSubFg, msym), :parametric)
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
vnd.val .= belief.val
vnd.bw .= belief.bw
DFG.refMeans(vnd)[1] = belief.val[1] #FIXME 🦨 shares data structure in belief
DFG.refCovariances(vnd)[1] = belief.bw
end
end
end
Expand All @@ -146,8 +146,8 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
#Update subfg variables
vnd.val[1] = val.val
vnd.bw .= val.cov
DFG.refMeans(vnd)[1] = val.val
DFG.refCovariances(vnd)[1] = val.cov
end
else
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result
Expand Down
16 changes: 7 additions & 9 deletions IncrementalInference/src/parametric/services/ParametricManopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function solve_RLM(

#Can use varIntLabel (because its an OrderedDict), but varLabelsAP makes the ArrayPartition.
p0 = map(varlabelsAP) do label
getVal(fg, label; solveKey)[1]
DFG.refMeans(getState(fg, label, solveKey))[1]
end

# create an ArrayPartition{CalcFactorResidual} for faclabels
Expand Down Expand Up @@ -415,7 +415,7 @@ function solve_RLM_conditional(

# get the subgraph formed by all frontals, separators and fully connected factors
varlabels = union(frontals, separators)
faclabels = sortDFG(setdiff(listNeighborhood(fg, varlabels, 1), varlabels))
_, faclabels = listNeighborhood(fg, varlabels, 1)

filter!(faclabels) do fl
return issubset(getVariableOrder(fg, fl), varlabels)
Expand All @@ -440,7 +440,7 @@ function solve_RLM_conditional(
all_varlabelsAP = ArrayPartition((frontal_varlabelsAP.x..., separator_varlabelsAP.x...))

all_points = map(all_varlabelsAP) do label
getVal(fg, label; solveKey)[1]
DFG.refMeans(getState(fg, label, solveKey))[1]
end

p0 = ArrayPartition(all_points.x[1:length(frontal_varlabelsAP.x)])
Expand Down Expand Up @@ -547,12 +547,10 @@ function autoinitParametric!(
return false
end

vnd::State = getState(xi, solveKey)

if perturb_point
_M = getManifold(xi)
p = vnd.val[1]
vnd.val[1] = exp(
p = DFG.refMeans(vnd)[1]
DFG.refMeans(vnd)[1] = exp(
_M,
p,
get_vector(
Expand All @@ -566,9 +564,9 @@ function autoinitParametric!(
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...)

val = lm_r[1]
vnd.val[1] = val
DFG.refMeans(vnd)[1] = val

!isnothing(Σ) && (vnd.bw .= Σ)
!isnothing(Σ) && (DFG.refCovariances(vnd)[1] .= Σ)

# updateSolverDataParametric!(vnd, val, Σ)

Expand Down
22 changes: 12 additions & 10 deletions IncrementalInference/src/parametric/services/ParametricUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ function initPoints!(p, gsc, fg::AbstractDFG, solveKey = :parametric)
for (i, vartype) in enumerate(gsc.varTypes)
varIds = gsc.varTypesIds[vartype]
for (j, vId) in enumerate(varIds)
p[gsc.M, i][j] = getState(fg, vId, solveKey).val[1]
p[gsc.M, i][j] = DFG.refMeans(getState(fg, vId, solveKey))[1]
end
end
end
Expand Down Expand Up @@ -687,7 +687,7 @@ function solveConditionalsParametric(
flatvar = FlatVariables(fg, varIds)

for vId in varIds
p = getState(fg, vId, solvekey).val[1]
p = DFG.refMeans(getState(fg, vId, solvekey))[1]
flatvar[vId] = getCoordinates(getStateKind(fg, vId), p)
end
initValues = flatvar.X
Expand Down Expand Up @@ -825,9 +825,9 @@ function updateSolverDataParametric!(
cov::AbstractMatrix,
)
# fill in the variable node data value
vnd.val[1] = val
DFG.refMeans(vnd)[1] = val
#calculate and fill in covariance
vnd.bw .= cov
DFG.refCovariances(vnd)[1] .= cov
return vnd
end

Expand Down Expand Up @@ -888,15 +888,15 @@ function initParametricFrom!(
for v in getVariables(fg)
fromvnd = getState(v, fromkey)
dims = getDimension(v)
getState(v, parkey).val[1] = fromvnd.val[1]
getState(v, parkey).bw[1:dims, 1:dims] = LinearAlgebra.I(dims)
DFG.refMeans(getState(v, parkey))[1] = DFG.refMeans(fromvnd)[1]
DFG.refCovariances(getState(v, parkey))[1] = LinearAlgebra.I(dims)
end
else
for var in getVariables(fg)
dims = getDimension(var)
μ, Σ = calcMeanCovar(var, fromkey)
getState(var, parkey).val[1] = μ
getState(var, parkey).bw[1:dims, 1:dims] = Σ
DFG.refMeans(getState(var, parkey))[1] = μ
DFG.refCovariances(getState(var, parkey))[1] = Σ
end
end
end
Expand Down Expand Up @@ -963,10 +963,12 @@ function createMvNormal(v::VariableCompute, key = :parametric)
if key == :parametric
vnd = getState(v, :parametric)
dims = getDimension(vnd)
return createMvNormal(vnd.val[1:dims, 1], vnd.bw[1:dims, 1:dims])
val = DFG.refMeans(vnd)[1]
cov = DFG.refCovariances(vnd)[1:dims, 1:dims]
return createMvNormal(val, cov)
else
@warn "Trying MvNormal Fit"
return fit(MvNormal, getState(v, key).val)
return fit(MvNormal, DFG.refPoints(getState(v, key)))
end
end

Expand Down
66 changes: 28 additions & 38 deletions IncrementalInference/src/services/FactorGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ reshapeVec2Mat(vec::Vector, rows::Int) = reshape(vec, rows, round(Int, length(ve

Fetch the variable marginal joint sampled points. Use [`getBelief`](@ref) to retrieve the full Belief object.
"""
getVal(v::VariableCompute; solveKey::Symbol = :default) = v.states[solveKey].val
#FIXME replace with refPoints
getVal(v::VariableCompute; solveKey::Symbol = :default) = DFG.refPoints(getState(v, solveKey))
function getVal(v::VariableCompute, idx::Int; solveKey::Symbol = :default)
return v.states[solveKey].val[:, idx]
return DFG.refPoints(getState(v, solveKey))[idx]
end
getVal(vnd::State) = vnd.val
getVal(vnd::State, idx::Int) = vnd.val[:, idx]
getVal(vnd::State) = DFG.refPoints(vnd)
getVal(vnd::State, idx::Int) = DFG.refPoints(vnd)[:, idx]
function getVal(dfg::AbstractDFG, lbl::Symbol; solveKey::Symbol = :default)
return getVariable(dfg, lbl).states[solveKey].val
return DFG.refPoints(getVariable(dfg, lbl).states[solveKey])
end

"""
Expand All @@ -73,15 +74,15 @@ function getNumPts(v::VariableCompute; solveKey::Symbol = :default)::Int
end

function AMP.getBW(vnd::State)
return vnd.bw
return DFG.refBandwidth(vnd)
end

# setVal! assumes you will update values to database separate, this used for local graph mods only
function getBWVal(v::VariableCompute; solveKey::Symbol = :default)
return getState(v, solveKey).bw
return DFG.refBandwidth(getState(v, solveKey))
end
function setBW!(vd::State, bw::Array{Float64, 2}; solveKey::Symbol = :default)
vd.bw = bw
DFG.refBandwidth(vd) .= bw
return nothing
end
function setBW!(v::VariableCompute, bw::Array{Float64, 2}; solveKey::Symbol = :default)
Expand All @@ -90,7 +91,9 @@ function setBW!(v::VariableCompute, bw::Array{Float64, 2}; solveKey::Symbol = :d
end

function setVal!(vd::State, val::AbstractVector{P}) where {P}
vd.val = val
points = DFG.refPoints(vd)
resize!(points, length(val))
points .= val
return nothing
end
function setVal!(
Expand Down Expand Up @@ -393,20 +396,13 @@ function DefaultNodeDataParametric(
# dims, false, :_null, Symbol[], variableType, true, 0.0, false, dontmargin)
else
ϵ = getPointIdentity(variableType)
return State(solveKey, variableType;
val=[ϵ],
bw=zeros(dims, dims),
# Symbol[],
# false,
# :_null,
# Symbol[],
initialized=false,
observability=zeros(dims),
marginalized=false,
# dontmargin,
# 0,
# 0,
belief = DFG.BeliefRepresentation(
DFG.GaussianDensityKind(),
variableType;
means = [ϵ],
covariances = [zeros(dims, dims)],
)
return State(solveKey, variableType; belief)
end
end

Expand Down Expand Up @@ -471,26 +467,20 @@ function setDefaultNodeData!(
#
(val, bw)
end

belief = DFG.BeliefRepresentation(
DFG.NonparametricDensityKind(),
varType;
points = val,
bandwidth = bw,
)
# make and set the new solverData
mergeState!(
v,
State(solveKey, varType;
# id=nothing,
val,
bw,
# Symbol[],
# sp,
# dims,
# false,
# :_null,
# Symbol[],
belief,
initialized=isinit,
observability=zeros(getDimension(v)),
marginalized=false,
# dontmargin,
# 0,
# 0,

)
)
return nothing
Expand Down Expand Up @@ -569,7 +559,7 @@ addVariable!(fg, :x0, Pose2)
function DFG.addVariable!(
dfg::AbstractDFG,
label::Symbol,
statetype::Union{T, Type{T}};
statekind::Union{T, Type{T}};
tags::Union{Set{Symbol}, Vector{Symbol}} = Set{Symbol}(),
timestamp::Union{TimeDateZone, ZonedDateTime} = DFG.now_tdz(),
solvable::Int = 1,
Expand Down Expand Up @@ -600,7 +590,7 @@ function DFG.addVariable!(
tags = union(Set(tags), [:VARIABLE])
v = VariableDFG(
label,
statetype;
statekind;
tags,
bloblets,
blobentries,
Expand Down
4 changes: 2 additions & 2 deletions IncrementalInference/src/services/GraphInit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ function initVariable!(
if solveKey == :parametric
μ, iΣ = getMeasurementParametric(samplable_belief)
vnd = getState(variable, solveKey)
vnd.val[1] = getPoint(getStateKind(variable), μ)
vnd.bw .= inv(iΣ)
DFG.refMeans(vnd)[1] = getPoint(getStateKind(variable), μ)
DFG.refCovariances(vnd)[1] .= inv(iΣ)
vnd.initialized = true
else
points = [samplePoint(M, samplable_belief) for _ = 1:N]
Expand Down
2 changes: 1 addition & 1 deletion IncrementalInference/src/services/VariableStatistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end

#TODO consolidate
function calcMeanCovar(vari::VariableCompute, solvekey = :default)
pts = getState(vari, solvekey).val
pts = DFG.refPoints(getState(vari, solvekey))
μ = mean(getManifold(vari), pts)
Σ = cov(getStateKind(vari), pts)
return μ, Σ
Expand Down
Loading
Loading