Skip to content

Commit 6b7bfef

Browse files
authored
Merge pull request #783 from JuliaRobotics/21Q2/enh/valVectorT
towards val::Vector{T}
2 parents 49590c9 + 4441e4b commit 6b7bfef

11 files changed

+291
-123
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1919
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2020
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
2121
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
22+
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
2223
TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
2324
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2425
Unmarshal = "cbff2730-442d-58d7-89d1-8e530c41eb02"
@@ -34,16 +35,18 @@ ManifoldsBase = "0.11, 0.12"
3435
Neo4j = "2"
3536
Pkg = "1.4, 1.5"
3637
Reexport = "0.2, 0.3, 0.4, 0.5, 1"
37-
Requires = "0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1"
38+
Requires = "0.5, 1"
39+
TensorCast = "0.3.3, 0.4"
3840
TimeZones = "1.3.1"
3941
Unmarshal = "0.4"
4042
julia = "1.4"
4143

4244
[extras]
4345
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
46+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4447
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
4548
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4649
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4750

4851
[targets]
49-
test = ["Test", "GraphPlot", "Manifolds", "Pkg"]
52+
test = ["Test", "GraphPlot", "LinearAlgebra", "Manifolds", "Pkg"]

src/DistributedFactorGraphs.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using LinearAlgebra
2828
using SparseArrays
2929
using UUIDs
3030
using Pkg
31+
using TensorCast
3132

3233
# used for @defVariable
3334
import ManifoldsBase
@@ -135,7 +136,10 @@ export getSolverData
135136
export getVariableType
136137

137138
# VariableType functions
138-
export getDimension, getManifolds, getManifold, getPointType
139+
export getDimension, getManifold, getPointType
140+
export getPointIdentity, getPoint, getCoordinates
141+
142+
export getManifolds # TODO Deprecate?
139143

140144
# Small Data CRUD
141145
export SmallDataTypes, getSmallData, addSmallData!, updateSmallData!, deleteSmallData!, listSmallData, emptySmallData!

src/entities/DFGVariable.jl

Lines changed: 95 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,58 +16,92 @@ Data container for solver-specific data.
1616
Fields:
1717
$(TYPEDFIELDS)
1818
"""
19-
mutable struct VariableNodeData{T<:InferenceVariable}
20-
val::Array{Float64,2}
21-
bw::Array{Float64,2}
22-
BayesNetOutVertIDs::Array{Symbol,1}
23-
dimIDs::Array{Int,1} # Likely deprecate
19+
mutable struct VariableNodeData{T<:InferenceVariable, P}
20+
val::Vector{P}
21+
bw::Matrix{Float64}
22+
BayesNetOutVertIDs::Vector{Symbol}
23+
dimIDs::Vector{Int} # Likely deprecate
24+
2425
dims::Int
2526
eliminated::Bool
2627
BayesNetVertID::Symbol # Union{Nothing, }
27-
separator::Array{Symbol,1}
28+
separator::Vector{Symbol}
29+
2830
variableType::T
2931
initialized::Bool
3032
inferdim::Float64
3133
ismargin::Bool
34+
3235
dontmargin::Bool
3336
solveInProgress::Int
3437
solvedCount::Int
3538
solveKey::Symbol
39+
3640
events::Dict{Symbol,Threads.Condition}
37-
VariableNodeData{T}(; solveKey::Symbol=:default) where {T <:InferenceVariable} =
38-
new{T}(zeros(1,1), zeros(1,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], T(), false, 0.0, false, false, 0, 0, solveKey, Dict{Symbol,Threads.Condition}())
39-
VariableNodeData{T}(val::Array{Float64,2},
40-
bw::Array{Float64,2},
41-
BayesNetOutVertIDs::Array{Symbol,1},
42-
dimIDs::Array{Int,1},
43-
dims::Int,eliminated::Bool,
44-
BayesNetVertID::Symbol,
45-
separator::Array{Symbol,1},
46-
variableType::T,
47-
initialized::Bool,
48-
inferdim::Float64,
49-
ismargin::Bool,
50-
dontmargin::Bool,
51-
solveInProgress::Int=0,
52-
solvedCount::Int=0,
53-
solveKey::Symbol=:default,
54-
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()) where T <: InferenceVariable =
55-
new{T}(val,bw,BayesNetOutVertIDs,dimIDs,dims,
56-
eliminated,BayesNetVertID,separator,
57-
variableType::T,initialized,inferdim,ismargin,
58-
dontmargin, solveInProgress, solvedCount, solveKey, events)
41+
42+
# VariableNodeData{T,P}() = new{T,P}()
43+
VariableNodeData{T,P}(w...) where {T <:InferenceVariable, P} = new{T,P}(w...)
44+
VariableNodeData{T,P}(;solveKey::Symbol=:default ) where {T <:InferenceVariable, P} = new{T,P}(
45+
Vector{P}(),
46+
zeros(0,0),
47+
Symbol[],
48+
Int[],
49+
0,
50+
false,
51+
:NOTHING,
52+
Symbol[],
53+
T(),
54+
false,
55+
0.0,
56+
false,
57+
false,
58+
0,
59+
0,
60+
solveKey,
61+
Dict{Symbol,Threads.Condition}() )
62+
#
5963
end
6064

6165
##------------------------------------------------------------------------------
6266
## Constructors
6367

64-
VariableNodeData(val::Array{Float64,2},
65-
bw::Array{Float64,2},
66-
BayesNetOutVertIDs::Array{Symbol,1},
67-
dimIDs::Array{Int,1},
68-
dims::Int,eliminated::Bool,
68+
VariableNodeData{T}(;solveKey::Symbol=:default ) where T <: InferenceVariable = VariableNodeData{T, getPointType(T)}(solveKey=solveKey)
69+
70+
VariableNodeData( val::Vector{P},
71+
bw::Matrix{<:Real},
72+
BayesNetOutVertIDs::AbstractVector{Symbol},
73+
dimIDs::AbstractVector{Int},
74+
dims::Int,
75+
eliminated::Bool,
76+
BayesNetVertID::Symbol,
77+
separator::Array{Symbol,1},
78+
variableType::T,
79+
initialized::Bool,
80+
inferdim::Float64,
81+
ismargin::Bool,
82+
dontmargin::Bool,
83+
solveInProgress::Int=0,
84+
solvedCount::Int=0,
85+
solveKey::Symbol=:default,
86+
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()
87+
) where {T <: InferenceVariable, P} = VariableNodeData{T,P}(
88+
val,bw,BayesNetOutVertIDs,dimIDs,dims,
89+
eliminated,BayesNetVertID,separator,
90+
variableType,initialized,inferdim,ismargin,
91+
dontmargin, solveInProgress, solvedCount, solveKey, events )
92+
#
93+
94+
95+
#
96+
97+
VariableNodeData(val::Vector{P},
98+
bw::Matrix{<:Real},
99+
BayesNetOutVertIDs::AbstractVector{Symbol},
100+
dimIDs::AbstractVector{Int},
101+
dims::Int,
102+
eliminated::Bool,
69103
BayesNetVertID::Symbol,
70-
separator::Array{Symbol,1},
104+
separator::AbstractVector{Symbol},
71105
variableType::T,
72106
initialized::Bool,
73107
inferdim::Float64,
@@ -76,16 +110,26 @@ VariableNodeData(val::Array{Float64,2},
76110
solveInProgress::Int=0,
77111
solvedCount::Int=0,
78112
solveKey::Symbol=:default
79-
) where T <: InferenceVariable =
80-
VariableNodeData{T}(val,bw,BayesNetOutVertIDs,dimIDs,dims,
81-
eliminated,BayesNetVertID,separator,
82-
variableType::T,initialized,inferdim,ismargin,
83-
dontmargin, solveInProgress, solvedCount,
84-
solveKey)
85-
86-
87-
VariableNodeData(variableType::T; solveKey::Symbol=:default) where T <: InferenceVariable =
88-
VariableNodeData{T}(zeros(1,1), zeros(1,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], variableType, false, 0.0, false, false, 0, 0, solveKey)
113+
) where {T <: InferenceVariable, P} =
114+
VariableNodeData{T,P}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
115+
eliminated,BayesNetVertID,separator,
116+
variableType,initialized,inferdim,ismargin,
117+
dontmargin, solveInProgress, solvedCount,
118+
solveKey )
119+
#
120+
121+
function VariableNodeData(variableType::T; solveKey::Symbol=:default) where T <: InferenceVariable
122+
#
123+
# p0 = getPointIdentity(T)
124+
P0 = Vector{getPointType(T)}()
125+
# P0[1] = p0
126+
BW = zeros(0,0)
127+
# BW[1] = zeros(getDimension(T))
128+
VariableNodeData( P0, BW, Symbol[], Int[],
129+
0, false, :NOTHING, Symbol[],
130+
variableType, false, 0.0, false,
131+
false, 0, 0, solveKey )
132+
end
89133

90134
##==============================================================================
91135
## PackedVariableNodeData.jl
@@ -237,7 +281,7 @@ struct DFGVariable{T<:InferenceVariable} <: AbstractDFGVariable
237281
ppeDict::Dict{Symbol, <: AbstractPointParametricEst}
238282
"""Dictionary of solver data. May be a subset of all solutions if a solver key was specified in the get call.
239283
Accessors: [`addVariableSolverData!`](@ref), [`updateVariableSolverData!`](@ref), and [`deleteVariableSolverData!`](@ref)"""
240-
solverDataDict::Dict{Symbol, VariableNodeData{T}}
284+
solverDataDict::Dict{Symbol, <: VariableNodeData{T}}
241285
"""Dictionary of small data associated with this variable.
242286
Accessors: [`getSmallData`](@ref), [`setSmallData!`](@ref)"""
243287
smallData::Dict{Symbol, SmallDataTypes}
@@ -261,10 +305,10 @@ function DFGVariable(label::Symbol, variableType::T;
261305
nstime::Nanosecond = Nanosecond(0),
262306
tags::Set{Symbol}=Set{Symbol}(),
263307
estimateDict::Dict{Symbol, <: AbstractPointParametricEst}=Dict{Symbol, MeanMaxPPE}(),
264-
solverDataDict::Dict{Symbol, VariableNodeData{T}}=Dict{Symbol, VariableNodeData{T}}(),
308+
solverDataDict::Dict{Symbol, VariableNodeData{T,P}}=Dict{Symbol, VariableNodeData{T,getPointType(T)}}(),
265309
smallData::Dict{Symbol, SmallDataTypes}=Dict{Symbol, SmallDataTypes}(),
266310
dataDict::Dict{Symbol, AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
267-
solvable::Int=1) where {T <: InferenceVariable}
311+
solvable::Int=1) where {T <: InferenceVariable, P}
268312

269313
if timestamp isa DateTime
270314
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, solverDataDict, smallData, dataDict, Ref(solvable))
@@ -280,12 +324,13 @@ function DFGVariable(label::Symbol,
280324
tags::Set{Symbol}=Set{Symbol}(),
281325
estimateDict::Dict{Symbol, <: AbstractPointParametricEst}=Dict{Symbol, MeanMaxPPE}(),
282326
smallData::Dict{Symbol, SmallDataTypes}=Dict{Symbol, SmallDataTypes}(),
283-
dataDict::Dict{Symbol, AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
327+
dataDict::Dict{Symbol, <: AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
284328
solvable::Int=1) where {T <: InferenceVariable}
329+
#
285330
if timestamp isa DateTime
286-
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T}}(:default=>solverData), smallData, dataDict, Ref(solvable))
331+
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T, getPointType(T)}}(:default=>solverData), smallData, dataDict, Ref(solvable))
287332
else
288-
DFGVariable{T}(label, timestamp, nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T}}(:default=>solverData), smallData, dataDict, Ref(solvable))
333+
DFGVariable{T}(label, timestamp, nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T, getPointType(T)}}(:default=>solverData), smallData, dataDict, Ref(solvable))
289334
end
290335
end
291336

src/services/CompareUtils.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,19 @@ function compareFactor(A::DFGFactor,
239239
TP = compareAll(A, B, skip=union([:attributes;:solverData;:_variableOrderSymbols],skip), show=show)
240240
# TP = TP & compareAll(A.attributes, B.attributes, skip=[:data;], show=show)
241241
TP = TP & compareAllSpecial(getSolverData(A), getSolverData(B), skip=union([:fnc], skip), show=show)
242+
if :fnc in skip
243+
return TP
244+
end
242245
TP = TP & compareAllSpecial(getSolverData(A).fnc, getSolverData(B).fnc, skip=union([:cpt;:measurement;:params;:varidx;:threadmodel], skip), show=show)
246+
if !(:measurement in skip)
243247
TP = TP & (skipsamples || compareAll(getSolverData(A).fnc.measurement, getSolverData(B).fnc.measurement, show=show, skip=skip))
244-
TP = TP & (skipcompute || compareAll(getSolverData(A).fnc.params, getSolverData(B).fnc.params, show=show, skip=skip))
245-
TP = TP & (skipcompute || compareAll(getSolverData(A).fnc.varidx, getSolverData(B).fnc.varidx, show=show, skip=skip))
248+
end
249+
if !(:params in skip)
250+
TP = TP & (skipcompute || compareAll(getSolverData(A).fnc.params, getSolverData(B).fnc.params, show=show, skip=skip))
251+
end
252+
if !(:varidx in skip)
253+
TP = TP & (skipcompute || compareAll(getSolverData(A).fnc.varidx, getSolverData(B).fnc.varidx, show=show, skip=skip))
254+
end
246255

247256
return TP
248257
end
@@ -339,11 +348,12 @@ Related:
339348
340349
`compareFactorGraphs`, `compareSimilarVariables`, `compareAllVariables`, `ls`.
341350
"""
342-
function compareSimilarFactors(fgA::G1,
343-
fgB::G2;
344-
skipsamples::Bool=true,
345-
skipcompute::Bool=true,
346-
show::Bool=true )::Bool where {G1 <: AbstractDFG, G2 <: AbstractDFG}
351+
function compareSimilarFactors( fgA::G1,
352+
fgB::G2;
353+
skipsamples::Bool=true,
354+
skipcompute::Bool=true,
355+
skip::AbstractVector{Symbol}=Symbol[],
356+
show::Bool=true ) where {G1 <: AbstractDFG, G2 <: AbstractDFG}
347357
#
348358
xlA = listFactors(fgA)
349359
xlB = listFactors(fgB)
@@ -354,7 +364,8 @@ function compareSimilarFactors(fgA::G1,
354364

355365
# compare the common set
356366
for var in xlAB
357-
TP = TP && compareFactor(getFactor(fgA, var), getFactor(fgB, var), skipsamples=skipsamples, skipcompute=skipcompute, show=show)
367+
TP = TP && compareFactor( getFactor(fgA, var), getFactor(fgB, var),
368+
skipsamples=skipsamples, skipcompute=skipcompute, skip=skip, show=show)
358369
end
359370

360371
# return comparison result
@@ -376,12 +387,12 @@ Related:
376387
377388
`compareSimilarVariables`, `compareSimilarFactors`, `compareAllVariables`, `ls`.
378389
"""
379-
function compareFactorGraphs(fgA::G1,
380-
fgB::G2;
381-
skipsamples::Bool=true,
382-
skipcompute::Bool=true,
383-
skip::Vector{Symbol}=Symbol[],
384-
show::Bool=true )::Bool where {G1 <: AbstractDFG, G2 <: AbstractDFG}
390+
function compareFactorGraphs( fgA::G1,
391+
fgB::G2;
392+
skipsamples::Bool=true,
393+
skipcompute::Bool=true,
394+
skip::Vector{Symbol}=Symbol[],
395+
show::Bool=true ) where {G1 <: AbstractDFG, G2 <: AbstractDFG}
385396
#
386397
skiplist = Symbol[:g;:bn;:IDs;:fIDs;:id;:nodeIDs;:factorIDs;:fifo;:solverParams; :factorOperationalMemoryType]
387398
skiplist = union(skiplist, skip)

src/services/CustomPrinting.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ function printVariable( io::IO, vert::DFGVariable;
5353
println(ioc, " initilized: ", isInitialized(vert, :default))
5454
println(ioc, " marginalized: ", isMarginalized(vert, :default))
5555
println(ioc, " size bel. samples: ", size(vnd.val))
56-
println(ioc, " kde bandwidths: ", round.((vnd.bw)[:,1], digits=4))
56+
print(ioc, " kde bandwidths: ")
57+
0 < length(vnd.bw) ? println(ioc, round.(vnd.bw[1], digits=4)) : nothing
5758
printstyled(ioc, " VNDs: ",bold=true)
5859
println(ioc, solk[smsk], 4<lsolk ? "..." : "")
5960
printstyled(ioc, " # PPE solveKeys= ($(length(getPPEDict(vert))))", bold=true)

0 commit comments

Comments
 (0)