Skip to content

Commit 6e253d0

Browse files
committed
DFG tests run through
1 parent 86836f3 commit 6e253d0

File tree

6 files changed

+123
-55
lines changed

6 files changed

+123
-55
lines changed

src/entities/DFGVariable.jl

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,47 +21,77 @@ mutable struct VariableNodeData{T<:InferenceVariable, P}
2121
bw::Vector{Vector{Float64}}
2222
BayesNetOutVertIDs::Vector{Symbol}
2323
dimIDs::Vector{Int} # Likely deprecate
24+
2425
dims::Int
2526
eliminated::Bool
2627
BayesNetVertID::Symbol # Union{Nothing, }
2728
separator::Vector{Symbol}
29+
2830
variableType::T
2931
initialized::Bool
3032
inferdim::Float64
3133
ismargin::Bool
34+
3235
dontmargin::Bool
3336
solveInProgress::Int
3437
solvedCount::Int
3538
solveKey::Symbol
39+
3640
events::Dict{Symbol,Threads.Condition}
37-
VariableNodeData{T,P}(; solveKey::Symbol=:default) where {T <:InferenceVariable, P} =
38-
new{T,getPointType(T)}(Vector{getPointType(T)}(undef, 1), Vector{Vector{Float64}}(undef,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], T(), false, 0.0, false, false, 0, 0, solveKey, Dict{Symbol,Threads.Condition}())
39-
VariableNodeData{T,P}(val::Vector{P},
40-
bw::Vector{Vector{Float64}},
41-
BayesNetOutVertIDs::AbstractVector{Symbol},
42-
dimIDs::AbstractVector{Int},
43-
dims::Int,
44-
eliminated::Bool,
45-
BayesNetVertID::Symbol,
46-
separator::Array{Symbol,1},
47-
variableType::T,
48-
initialized::Bool,
49-
inferdim::Float64,
50-
ismargin::Bool,
51-
dontmargin::Bool,
52-
solveInProgress::Int=0,
53-
solvedCount::Int=0,
54-
solveKey::Symbol=:default,
55-
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()) where {T <: InferenceVariable, P} =
56-
new{T,P}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
57-
eliminated,BayesNetVertID,separator,
58-
variableType,initialized,inferdim,ismargin,
59-
dontmargin, solveInProgress, solvedCount, solveKey, events)
41+
42+
# VariableNodeData{T,P}() = new{T,P}()
43+
VariableNodeData{T,P}(w...) where {T <:InferenceVariable, P} = new{T,P}(w...)
44+
VariableNodeData{T,P}(;solveKey::Symbol=:default ) where {T <:InferenceVariable, P} = new{T,P}(
45+
Vector{P}(),
46+
Vector{Vector{Float64}}(),
47+
Symbol[],
48+
Int[],
49+
0,
50+
false,
51+
:NOTHING,
52+
Symbol[],
53+
T(),
54+
false,
55+
0.0,
56+
false,
57+
false,
58+
0,
59+
0,
60+
solveKey,
61+
Dict{Symbol,Threads.Condition}() )
62+
#
6063
end
6164

6265
##------------------------------------------------------------------------------
6366
## Constructors
6467

68+
VariableNodeData{T}(;solveKey::Symbol=:default ) where T <: InferenceVariable = VariableNodeData{T, getPointType(T)}(solveKey=solveKey)
69+
70+
VariableNodeData( val::Vector{P},
71+
bw::Vector{Vector{Float64}},
72+
BayesNetOutVertIDs::AbstractVector{Symbol},
73+
dimIDs::AbstractVector{Int},
74+
dims::Int,
75+
eliminated::Bool,
76+
BayesNetVertID::Symbol,
77+
separator::Array{Symbol,1},
78+
variableType::T,
79+
initialized::Bool,
80+
inferdim::Float64,
81+
ismargin::Bool,
82+
dontmargin::Bool,
83+
solveInProgress::Int=0,
84+
solvedCount::Int=0,
85+
solveKey::Symbol=:default,
86+
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()
87+
) where {T <: InferenceVariable, P} = VariableNodeData{T,P}(
88+
val,bw,BayesNetOutVertIDs,dimIDs,dims,
89+
eliminated,BayesNetVertID,separator,
90+
variableType,initialized,inferdim,ismargin,
91+
dontmargin, solveInProgress, solvedCount, solveKey, events )
92+
#
93+
94+
6595
#
6696

6797
VariableNodeData(val::Vector{P},
@@ -81,7 +111,7 @@ VariableNodeData(val::Vector{P},
81111
solvedCount::Int=0,
82112
solveKey::Symbol=:default
83113
) where {T <: InferenceVariable, P} =
84-
VariableNodeData{T}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
114+
VariableNodeData{T,P}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
85115
eliminated,BayesNetVertID,separator,
86116
variableType,initialized,inferdim,ismargin,
87117
dontmargin, solveInProgress, solvedCount,
@@ -90,11 +120,11 @@ VariableNodeData(val::Vector{P},
90120

91121
function VariableNodeData(variableType::T; solveKey::Symbol=:default) where T <: InferenceVariable
92122
#
93-
p0 = getPointIdentity(T)
94-
P0 = Vector{getPointType(T)}(undef, 1)
95-
P0[1] = p0
96-
BW = Vector{Vector{Float64}}(undef,1)
97-
BW[1] = zeros(getDimension(T))
123+
# p0 = getPointIdentity(T)
124+
P0 = Vector{getPointType(T)}()
125+
# P0[1] = p0
126+
BW = Vector{Vector{Float64}}()
127+
# BW[1] = zeros(getDimension(T))
98128
VariableNodeData( P0, BW, Symbol[], Int[],
99129
0, false, :NOTHING, Symbol[],
100130
variableType, false, 0.0, false,
@@ -251,7 +281,7 @@ struct DFGVariable{T<:InferenceVariable} <: AbstractDFGVariable
251281
ppeDict::Dict{Symbol, <: AbstractPointParametricEst}
252282
"""Dictionary of solver data. May be a subset of all solutions if a solver key was specified in the get call.
253283
Accessors: [`addVariableSolverData!`](@ref), [`updateVariableSolverData!`](@ref), and [`deleteVariableSolverData!`](@ref)"""
254-
solverDataDict::Dict{Symbol, VariableNodeData{T}}
284+
solverDataDict::Dict{Symbol, <: VariableNodeData{T}}
255285
"""Dictionary of small data associated with this variable.
256286
Accessors: [`getSmallData`](@ref), [`setSmallData!`](@ref)"""
257287
smallData::Dict{Symbol, SmallDataTypes}
@@ -275,10 +305,10 @@ function DFGVariable(label::Symbol, variableType::T;
275305
nstime::Nanosecond = Nanosecond(0),
276306
tags::Set{Symbol}=Set{Symbol}(),
277307
estimateDict::Dict{Symbol, <: AbstractPointParametricEst}=Dict{Symbol, MeanMaxPPE}(),
278-
solverDataDict::Dict{Symbol, VariableNodeData{T}}=Dict{Symbol, VariableNodeData{T}}(),
308+
solverDataDict::Dict{Symbol, VariableNodeData{T,P}}=Dict{Symbol, VariableNodeData{T,getPointType(T)}}(),
279309
smallData::Dict{Symbol, SmallDataTypes}=Dict{Symbol, SmallDataTypes}(),
280310
dataDict::Dict{Symbol, AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
281-
solvable::Int=1) where {T <: InferenceVariable}
311+
solvable::Int=1) where {T <: InferenceVariable, P}
282312

283313
if timestamp isa DateTime
284314
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, solverDataDict, smallData, dataDict, Ref(solvable))
@@ -294,12 +324,13 @@ function DFGVariable(label::Symbol,
294324
tags::Set{Symbol}=Set{Symbol}(),
295325
estimateDict::Dict{Symbol, <: AbstractPointParametricEst}=Dict{Symbol, MeanMaxPPE}(),
296326
smallData::Dict{Symbol, SmallDataTypes}=Dict{Symbol, SmallDataTypes}(),
297-
dataDict::Dict{Symbol, AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
327+
dataDict::Dict{Symbol, <: AbstractDataEntry}=Dict{Symbol,AbstractDataEntry}(),
298328
solvable::Int=1) where {T <: InferenceVariable}
329+
#
299330
if timestamp isa DateTime
300-
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T}}(:default=>solverData), smallData, dataDict, Ref(solvable))
331+
DFGVariable{T}(label, ZonedDateTime(timestamp, localzone()), nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T, getPointType(T)}}(:default=>solverData), smallData, dataDict, Ref(solvable))
301332
else
302-
DFGVariable{T}(label, timestamp, nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T}}(:default=>solverData), smallData, dataDict, Ref(solvable))
333+
DFGVariable{T}(label, timestamp, nstime, tags, estimateDict, Dict{Symbol, VariableNodeData{T, getPointType(T)}}(:default=>solverData), smallData, dataDict, Ref(solvable))
303334
end
304335
end
305336

src/services/CustomPrinting.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ function printVariable( io::IO, vert::DFGVariable;
5353
println(ioc, " initilized: ", isInitialized(vert, :default))
5454
println(ioc, " marginalized: ", isMarginalized(vert, :default))
5555
println(ioc, " size bel. samples: ", size(vnd.val))
56-
println(ioc, " kde bandwidths: ", round.((vnd.bw)[:,1], digits=4))
56+
print(ioc, " kde bandwidths: ")
57+
0 < length(vnd.bw) ? println(ioc, round.(vnd.bw[1], digits=4)) : nothing
5758
printstyled(ioc, " VNDs: ",bold=true)
5859
println(ioc, solk[smsk], 4<lsolk ? "..." : "")
5960
printstyled(ioc, " # PPE solveKeys= ($(length(getPPEDict(vert))))", bold=true)

src/services/Serialization.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ end
109109
##==============================================================================
110110
## Variable Packing and unpacking
111111
##==============================================================================
112-
function packVariable(dfg::G, v::DFGVariable) where G <: AbstractDFG
112+
function packVariable(dfg::AbstractDFG, v::DFGVariable)
113113
props = Dict{String, Any}()
114114
props["label"] = string(v.label)
115115
props["timestamp"] = Dates.format(v.timestamp, "yyyy-mm-ddTHH:MM:SS.ssszzz")
@@ -213,10 +213,21 @@ end
213213
# returns a PackedVariableNodeData
214214
function packVariableNodeData(::G, d::VariableNodeData{T}) where {G <: AbstractDFG, T <: InferenceVariable}
215215
@debug "Dispatching conversion variable -> packed variable for type $(string(d.variableType))"
216-
precast = getCoordinates.(T, d.val)
217-
@cast castval[i,j] := precast[j][i]
218-
_val = precast[:]
219-
@cast castbw[i,j] := d.bw[j][i]
216+
# @show d.val
217+
castval = if 0 < length(d.val)
218+
precast = getCoordinates.(T, d.val)
219+
@cast castval[i,j] := precast[j][i]
220+
castval
221+
else
222+
zeros(1,0)
223+
end
224+
_val = castval[:]
225+
castbw = if 0 < length(d.bw)
226+
@cast castbw[i,j] := d.bw[j][i]
227+
castbw
228+
else
229+
zeros(1,0)
230+
end
220231
_bw = castbw[:]
221232
return PackedVariableNodeData(_val, size(castval,1),
222233
_bw, size(castbw,1),
@@ -245,16 +256,26 @@ function unpackVariableNodeData(dfg::G, d::PackedVariableNodeData) where G <: Ab
245256
c3 = r3 > 0 ? floor(Int,length(d.vecval)/r3) : 0
246257
M3 = reshape(d.vecval,r3,c3)
247258
@cast val_[j][i] := M3[i,j]
248-
vals = getPoint.(T, val_)
259+
vals = Vector{getPointType(T)}(undef, length(val_))
260+
# vals = getPoint.(T, val_)
261+
for v in val_
262+
vals[i] = getPoint(T, v)
263+
end
249264

250265
r4 = d.dimbw
251266
c4 = r4 > 0 ? floor(Int,length(d.vecbw)/r4) : 0
252267
M4 = reshape(d.vecbw,r4,c4)
253-
@cast bw[j][i] := M4[i,j]
254-
255-
return VariableNodeData{T}(vals, bw, d.BayesNetOutVertIDs,
268+
bw = Vector{Vector{Float64}}(undef,size(M4,2))
269+
for j in 1:size(M4,2)
270+
bw[j] = collect(M4[:,j])
271+
end
272+
273+
#
274+
return VariableNodeData{T, getPointType(T)}(vals, bw, d.BayesNetOutVertIDs,
256275
d.dimIDs, d.dims, d.eliminated, d.BayesNetVertID, d.separator,
257-
st(), d.initialized, d.inferdim, d.ismargin, d.dontmargin, d.solveInProgress, d.solvedCount, d.solveKey)
276+
T(), d.initialized, d.inferdim, d.ismargin, d.dontmargin,
277+
d.solveInProgress, d.solvedCount, d.solveKey,
278+
Dict{Symbol,Threads.Condition}() )
258279
end
259280

260281
##==============================================================================

test/compareTests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ vnd2 = deepcopy(vnd1)
1313
vnd3 = VariableNodeData(TestVariableType2())
1414

1515
@test vnd1 == vnd2
16-
vnd1.val[1] = [1.0;]
17-
vnd2.val = Vector{Vector{Float64}}(undef,1)
18-
vnd2.val[1] = [1.0;]
16+
push!(vnd1.val, [1.0;])
17+
# vnd2.val = Vector{Vector{Float64}}(undef,1)
18+
push!(vnd2.val, [1.0;])
1919
@test vnd1 == vnd2
2020
vnd2.val[1] = [0.1;]
2121
@test !(vnd1 == vnd2)
@@ -68,8 +68,8 @@ vnd3 = VariableNodeData(TestVariableType2())
6868
@test !compare(vnd1, vnd3)
6969

7070
@test compare(vnd1, vnd2)
71-
vnd1.val[1][1] = 1.0
72-
vnd2.val[1][1] = 1.0
71+
push!(vnd1.val, [1.0;])
72+
push!(vnd2.val, [1.0;])
7373
@test compare(vnd1, vnd2)
7474
vnd2.val[1][1] = 0.1
7575
@test !compare(vnd1, vnd2)

test/plottingTest.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ using GraphPlot
22
using DistributedFactorGraphs
33
# using DistributedFactorGraphs.DFGPlots
44
using Test
5+
using Manifolds
56

6-
struct TestInferenceVariable1 <: InferenceVariable end
7+
# struct TestInferenceVariable1 <: InferenceVariable end
8+
@defVariable TestInferenceVariable1 Euclidean(1) [0.0;]
79

810
# Now make a complex graph for connectivity tests
911
numNodes = 10

test/testBlocks.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,14 @@ function DFGVariableSCA()
255255

256256
vorphan = DFGVariable(:orphan, TestVariableType1(), tags=v1_tags, solvable=0, solverDataDict=Dict(:default=>VariableNodeData{TestVariableType1}()))
257257

258+
# v1.solverDataDict[:default].val[1] = [0.0;]
259+
# v1.solverDataDict[:default].bw[1] = [1.0;]
260+
# v2.solverDataDict[:default].val[1] = [0.0;0.0]
261+
# v2.solverDataDict[:default].bw[1] = [1.0;1.0]
262+
# v3.solverDataDict[:default].val[1] = [0.0;0.0]
263+
# v3.solverDataDict[:default].bw[1] = [1.0;1.0]
264+
265+
258266
getSolverData(v1).solveInProgress = 1
259267

260268
@test getLabel(v1) == v1_lbl
@@ -702,6 +710,8 @@ function VSDTestBlock!(fg, v1)
702710
# - `getSolveInProgress`
703711

704712
vnd = VariableNodeData{TestVariableType1}(solveKey=:parametric)
713+
# vnd.val[1] = [0.0;]
714+
# vnd.bw[1] = [1.0;]
705715
@test addVariableSolverData!(fg, :a, vnd) == vnd
706716

707717
@test_throws ErrorException addVariableSolverData!(fg, :a, vnd)
@@ -731,7 +741,7 @@ function VSDTestBlock!(fg, v1)
731741
retVnd = updateVariableSolverData!(fg, :a, altVnd, false, [:inferdim;])
732742
@test retVnd == altVnd
733743

734-
fill!(altVnd.bw, -1.0)
744+
push!(altVnd.bw, [-1.0;])
735745
retVnd = updateVariableSolverData!(fg, :a, altVnd, false, [:bw;])
736746
@test retVnd == altVnd
737747

@@ -756,6 +766,9 @@ function VSDTestBlock!(fg, v1)
756766
# Could also do VariableNodeData(ContinuousScalar())
757767

758768
vnd = VariableNodeData{TestVariableType1}(solveKey=:parametric)
769+
# vnd.val[1] = [0.0;]
770+
# vnd.bw[1] = [1.0;]
771+
759772
addVariableSolverData!(fg, :a, vnd)
760773
@test setdiff(listVariableSolverData(fg, :a), [:default, :parametric]) == []
761774
# Get the data back - note that this is a reference to above.
@@ -1539,7 +1552,7 @@ function FileDFGTestBlock(testDFGAPI; kwargs...)
15391552
# set everything
15401553
vnd.BayesNetVertID = :outid
15411554
push!(vnd.BayesNetOutVertIDs, :id)
1542-
vnd.bw[1] = 1.0
1555+
# vnd.bw[1] = [1.0;]
15431556
push!(vnd.dimIDs, 1)
15441557
vnd.dims = 1
15451558
vnd.dontmargin = true
@@ -1550,7 +1563,7 @@ function FileDFGTestBlock(testDFGAPI; kwargs...)
15501563
push!(vnd.separator, :sep)
15511564
vnd.solveInProgress = 1
15521565
vnd.solvedCount = 2
1553-
vnd.val .= 2.0
1566+
# vnd.val[1] = [2.0;]
15541567
#update
15551568
updateVariable!(dfg, v4)
15561569

0 commit comments

Comments
 (0)