Skip to content

Commit 3ac6f34

Browse files
committed
towards val::Vector{T}
1 parent 96ad293 commit 3ac6f34

File tree

5 files changed

+131
-64
lines changed

5 files changed

+131
-64
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1919
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2020
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
2121
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
22+
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
2223
TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
2324
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2425
Unmarshal = "cbff2730-442d-58d7-89d1-8e530c41eb02"
@@ -35,6 +36,7 @@ Neo4j = "2"
3536
Pkg = "1.4, 1.5"
3637
Reexport = "0.2, 0.3, 0.4, 0.5, 1"
3738
Requires = "0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1"
39+
TensorCast = "0.3.3, 0.4"
3840
TimeZones = "1.3.1"
3941
Unmarshal = "0.4"
4042
julia = "1.4"

src/DistributedFactorGraphs.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using LinearAlgebra
2828
using SparseArrays
2929
using UUIDs
3030
using Pkg
31+
using TensorCast
3132

3233
# used for @defVariable
3334
import ManifoldsBase
@@ -135,7 +136,10 @@ export getSolverData
135136
export getVariableType
136137

137138
# VariableType functions
138-
export getDimension, getManifolds, getManifold, getPointType
139+
export getDimension, getManifold, getPointType
140+
export getPointIdentity, getPoint, getCoordinates
141+
142+
export getManifolds # TODO Deprecate?
139143

140144
# Small Data CRUD
141145
export SmallDataTypes, getSmallData, addSmallData!, updateSmallData!, deleteSmallData!, listSmallData, emptySmallData!

src/entities/DFGVariable.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ Data container for solver-specific data.
1616
Fields:
1717
$(TYPEDFIELDS)
1818
"""
19-
mutable struct VariableNodeData{T<:InferenceVariable}
20-
val::Array{Float64,2}
21-
bw::Array{Float64,2}
19+
mutable struct VariableNodeData{T<:InferenceVariable, P}
20+
val::Vector{P}
21+
bw::Vector{Vector{Float64}}
2222
BayesNetOutVertIDs::Array{Symbol,1}
2323
dimIDs::Array{Int,1} # Likely deprecate
2424
dims::Int
@@ -35,9 +35,9 @@ mutable struct VariableNodeData{T<:InferenceVariable}
3535
solveKey::Symbol
3636
events::Dict{Symbol,Threads.Condition}
3737
VariableNodeData{T}(; solveKey::Symbol=:default) where {T <:InferenceVariable} =
38-
new{T}(zeros(1,1), zeros(1,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], T(), false, 0.0, false, false, 0, 0, solveKey, Dict{Symbol,Threads.Condition}())
39-
VariableNodeData{T}(val::Array{Float64,2},
40-
bw::Array{Float64,2},
38+
new{T,Vector{Vector{Float64}}}([[0.0;];], [[0.0;];], Symbol[], Int[], 0, false, :NOTHING, Symbol[], T(), false, 0.0, false, false, 0, 0, solveKey, Dict{Symbol,Threads.Condition}())
39+
VariableNodeData{T}(val::Vector{P},
40+
bw::Vector{Vector{Float64}},
4141
BayesNetOutVertIDs::Array{Symbol,1},
4242
dimIDs::Array{Int,1},
4343
dims::Int,eliminated::Bool,
@@ -51,18 +51,18 @@ mutable struct VariableNodeData{T<:InferenceVariable}
5151
solveInProgress::Int=0,
5252
solvedCount::Int=0,
5353
solveKey::Symbol=:default,
54-
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()) where T <: InferenceVariable =
55-
new{T}(val,bw,BayesNetOutVertIDs,dimIDs,dims,
56-
eliminated,BayesNetVertID,separator,
57-
variableType::T,initialized,inferdim,ismargin,
58-
dontmargin, solveInProgress, solvedCount, solveKey, events)
54+
events::Dict{Symbol,Threads.Condition}=Dict{Symbol,Threads.Condition}()) where {T <: InferenceVariable, P, B} =
55+
new{T,P}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
56+
eliminated,BayesNetVertID,separator,
57+
variableType,initialized,inferdim,ismargin,
58+
dontmargin, solveInProgress, solvedCount, solveKey, events)
5959
end
6060

6161
##------------------------------------------------------------------------------
6262
## Constructors
6363

64-
VariableNodeData(val::Array{Float64,2},
65-
bw::Array{Float64,2},
64+
VariableNodeData(val::Vector{P},
65+
bw::Vector{Vector{Float64}},
6666
BayesNetOutVertIDs::Array{Symbol,1},
6767
dimIDs::Array{Int,1},
6868
dims::Int,eliminated::Bool,
@@ -76,16 +76,16 @@ VariableNodeData(val::Array{Float64,2},
7676
solveInProgress::Int=0,
7777
solvedCount::Int=0,
7878
solveKey::Symbol=:default
79-
) where T <: InferenceVariable =
80-
VariableNodeData{T}(val,bw,BayesNetOutVertIDs,dimIDs,dims,
81-
eliminated,BayesNetVertID,separator,
82-
variableType::T,initialized,inferdim,ismargin,
83-
dontmargin, solveInProgress, solvedCount,
84-
solveKey)
79+
) where {T <: InferenceVariable, P} =
80+
VariableNodeData{T,P}( val,bw,BayesNetOutVertIDs,dimIDs,dims,
81+
eliminated,BayesNetVertID,separator,
82+
variableType,initialized,inferdim,ismargin,
83+
dontmargin, solveInProgress, solvedCount,
84+
solveKey)
8585

8686

8787
VariableNodeData(variableType::T; solveKey::Symbol=:default) where T <: InferenceVariable =
88-
VariableNodeData{T}(zeros(1,1), zeros(1,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], variableType, false, 0.0, false, false, 0, 0, solveKey)
88+
VariableNodeData{T,Vector{getPointType(T)}}([[0.0;];], [[0.0;];], Symbol[], Int[], 0, false, :NOTHING, Symbol[], variableType, false, 0.0, false, false, 0, 0, solveKey)
8989

9090
##==============================================================================
9191
## PackedVariableNodeData.jl

src/services/DFGVariable.jl

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ getLastUpdatedTimestamp(est::AbstractPointParametricEst) = est.lastUpdatedTimest
3434
## variableType
3535
##------------------------------------------------------------------------------
3636
"""
37-
$(SIGNATURES)
37+
$(SIGNATURES)
3838
3939
Variable nodes `variableType` information holding a variety of meta data associated with the type of variable stored in that node of the factor graph.
4040
@@ -48,8 +48,8 @@ getVariableType
4848
getVariableType(v::DFGVariable{T}) where T <: InferenceVariable = T()
4949

5050
function getVariableType(vnd::VariableNodeData)
51-
# @warn "getVariableType(::VariableNodeData) is being deprecated, use getVariableType(::DFGVariable) instead."
52-
return vnd.variableType
51+
# @warn "getVariableType(::VariableNodeData) is being deprecated, use getVariableType(::DFGVariable) instead."
52+
return vnd.variableType
5353
end
5454

5555

@@ -83,10 +83,10 @@ See documentation in [Manifolds.jl on making your own](https://juliamanifolds.gi
8383
8484
Example:
8585
```
86-
DFG.@defVariable Pose2 SpecialEuclidean(2) ProductRepr{Tuple{Vector{Float64}, Matrix{Float64}}}
86+
DFG.@defVariable Pose2 SpecialEuclidean(2) ProductRepr([0;0.0],[1 0; 0 1.0])
8787
```
8888
"""
89-
macro defVariable(structname, manifold, point_type)
89+
macro defVariable(structname, manifold, point_identity)
9090
return esc(quote
9191
Base.@__doc__ struct $structname <: InferenceVariable end
9292

@@ -95,7 +95,9 @@ macro defVariable(structname, manifold, point_type)
9595

9696
DFG.getManifold(::Type{$structname}) = $manifold
9797

98-
DFG.getPointType(::Type{$structname}) = $point_type
98+
DFG.getPointType(::Type{$structname}) = typeof($point_identity)
99+
100+
DFG.getPointIdentity(::Type{$structname}) = $point_identity
99101

100102
end)
101103
end
@@ -129,6 +131,48 @@ Interface function to return the manifold point type of an InferenceVariable, ex
129131
function getPointType end
130132
getPointType(::T) where {T <: InferenceVariable} = getPointType(T)
131133

134+
"""
135+
$SIGNATURES
136+
Interface function to return the user provided identity point for this InferenceVariable manifold, extend this function for all Types<:InferenceVariable.
137+
"""
138+
function getPointIdentity end
139+
getPointIdentity(::T) where {T <: InferenceVariable} = getPointIdentity(T)
140+
141+
142+
"""
143+
$SIGNATURES
144+
145+
Default escalzation from coordinates to a group representation point. Override if defaults are not correct.
146+
E.g. coords -> se(2) -> SE(2).
147+
148+
Related
149+
150+
[`getCoordinates`](@ref)
151+
"""
152+
function getPoint(::Type{T}, v::AbstractVector) where {T <: InferenceVariable}
153+
M = getManifold(T)
154+
p0 = getPointIdentity(T)
155+
X = get_vector(M, p0, v, DefaultOrthonormalBasis())
156+
exp(M, p0, X)
157+
end
158+
159+
"""
160+
$SIGNATURES
161+
162+
Default reduction of a variable point value (a group element) into coordinates as `Vector`. Override if defaults are not correct.
163+
164+
Related
165+
166+
[`getPoint`](@ref)
167+
"""
168+
function getCoordinates(::Type{T}, p) where {T <: InferenceVariable}
169+
M = getManifold(T)
170+
p0 = getPointIdentity(T)
171+
X = log(M, p0, p)
172+
get_coordinates(M, p0, X, DefaultOrthonormalBasis())
173+
end
174+
175+
132176
##------------------------------------------------------------------------------
133177
## solvedCount
134178
##------------------------------------------------------------------------------
@@ -480,7 +524,7 @@ end
480524
$SIGNATURES
481525
Retrieve the soft type name symbol for a DFGVariableSummary. ie :Point2, Pose2, etc.
482526
"""
483-
getVariableTypeName(v::DFGVariableSummary)::Symbol = v.variableTypeName
527+
getVariableTypeName(v::DFGVariableSummary) = v.variableTypeName::Symbol
484528

485529

486530
function getVariableType(v::DFGVariableSummary)::InferenceVariable
@@ -507,7 +551,7 @@ end
507551
$(SIGNATURES)
508552
Get variable solverdata for a given solve key.
509553
"""
510-
function getVariableSolverData(dfg::AbstractDFG, variablekey::Symbol, solvekey::Symbol=:default)::VariableNodeData
554+
function getVariableSolverData(dfg::AbstractDFG, variablekey::Symbol, solvekey::Symbol=:default)
511555
v = getVariable(dfg, variablekey)
512556
!haskey(v.solverDataDict, solvekey) && error("Solve key '$solvekey' not found in variable '$variablekey'")
513557
return v.solverDataDict[solvekey]
@@ -518,7 +562,7 @@ end
518562
$(SIGNATURES)
519563
Add variable solver data, errors if it already exists.
520564
"""
521-
function addVariableSolverData!(dfg::AbstractDFG, variablekey::Symbol, vnd::VariableNodeData)::VariableNodeData
565+
function addVariableSolverData!(dfg::AbstractDFG, variablekey::Symbol, vnd::VariableNodeData)
522566
var = getVariable(dfg, variablekey)
523567
if haskey(var.solverDataDict, vnd.solveKey)
524568
error("VariableNodeData '$(vnd.solveKey)' already exists")
@@ -655,7 +699,7 @@ const deepcopySupersolve! = deepcopySolvekeys!
655699
$(SIGNATURES)
656700
Delete variable solver data, returns the deleted element.
657701
"""
658-
function deleteVariableSolverData!(dfg::AbstractDFG, variablekey::Symbol, solveKey::Symbol=:default)::VariableNodeData
702+
function deleteVariableSolverData!(dfg::AbstractDFG, variablekey::Symbol, solveKey::Symbol=:default)
659703
var = getVariable(dfg, variablekey)
660704

661705
if !haskey(var.solverDataDict, solveKey)
@@ -691,7 +735,7 @@ Merges and updates solver and estimate data for a variable (variable can be from
691735
If the same key is present in another collection, the value for that key will be the value it has in the last collection listed (updated).
692736
Note: Makes a copy of the estimates and solver data so that there is no coupling between graphs.
693737
"""
694-
function mergeVariableSolverData!(destVariable::DFGVariable, sourceVariable::DFGVariable)::DFGVariable
738+
function mergeVariableSolverData!(destVariable::DFGVariable, sourceVariable::DFGVariable)
695739
# We don't know which graph this came from, must be copied!
696740
merge!(destVariable.solverDataDict, deepcopy(sourceVariable.solverDataDict))
697741
return destVariable
@@ -786,7 +830,7 @@ end
786830
$(SIGNATURES)
787831
Delete PPE data, returns the deleted element.
788832
"""
789-
function deletePPE!(dfg::AbstractDFG, variablekey::Symbol, ppekey::Symbol=:default)::AbstractPointParametricEst
833+
function deletePPE!(dfg::AbstractDFG, variablekey::Symbol, ppekey::Symbol=:default)
790834
var = getVariable(dfg, variablekey)
791835

792836
if !haskey(var.ppeDict, ppekey)
@@ -811,9 +855,9 @@ deletePPE!(dfg::AbstractDFG, sourceVariable::DFGVariable, ppekey::Symbol=:defaul
811855
$(SIGNATURES)
812856
List all the PPE data keys in the variable.
813857
"""
814-
function listPPEs(dfg::AbstractDFG, variablekey::Symbol)::Vector{Symbol}
858+
function listPPEs(dfg::AbstractDFG, variablekey::Symbol)
815859
v = getVariable(dfg, variablekey)
816-
return collect(keys(v.ppeDict))
860+
return collect(keys(v.ppeDict))::Vector{Symbol}
817861
end
818862

819863
#TODO API and only correct level
@@ -822,7 +866,7 @@ end
822866
Merges and updates solver and estimate data for a variable (variable can be from another graph).
823867
Note: Makes a copy of the estimates and solver data so that there is no coupling between graphs.
824868
"""
825-
function mergePPEs!(destVariable::AbstractDFGVariable, sourceVariable::AbstractDFGVariable)::AbstractDFGVariable
869+
function mergePPEs!(destVariable::AbstractDFGVariable, sourceVariable::AbstractDFGVariable)
826870
# We don't know which graph this came from, must be copied!
827871
merge!(destVariable.ppeDict, deepcopy(sourceVariable.ppeDict))
828872
return destVariable

src/services/Serialization.jl

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ end
109109
##==============================================================================
110110
## Variable Packing and unpacking
111111
##==============================================================================
112-
function packVariable(dfg::G, v::DFGVariable)::Dict{String, Any} where G <: AbstractDFG
112+
function packVariable(dfg::G, v::DFGVariable) where G <: AbstractDFG
113113
props = Dict{String, Any}()
114114
props["label"] = string(v.label)
115115
props["timestamp"] = Dates.format(v.timestamp, "yyyy-mm-ddTHH:MM:SS.ssszzz")
@@ -123,14 +123,15 @@ function packVariable(dfg::G, v::DFGVariable)::Dict{String, Any} where G <: Abst
123123
props["dataEntry"] = JSON2.write(Dict(keys(v.dataDict) .=> map(bde -> JSON.json(bde), values(v.dataDict))))
124124
props["dataEntryType"] = JSON2.write(Dict(keys(v.dataDict) .=> map(bde -> typeof(bde), values(v.dataDict))))
125125
props["_version"] = _getDFGVersion()
126-
return props
126+
return props::Dict{String, Any}
127127
end
128128

129+
# returns a DFGVariable
129130
function unpackVariable(dfg::G,
130-
packedProps::Dict{String, Any};
131-
unpackPPEs::Bool=true,
132-
unpackSolverData::Bool=true,
133-
unpackBigData::Bool=true)::DFGVariable where G <: AbstractDFG
131+
packedProps::Dict{String, Any};
132+
unpackPPEs::Bool=true,
133+
unpackSolverData::Bool=true,
134+
unpackBigData::Bool=true) where G <: AbstractDFG
134135
@debug "Unpacking variable:\r\n$packedProps"
135136
# Version checking.
136137
_versionCheck(packedProps)
@@ -202,10 +203,23 @@ function unpackVariable(dfg::G,
202203
return variable
203204
end
204205

205-
function packVariableNodeData(dfg::G, d::VariableNodeData)::PackedVariableNodeData where G <: AbstractDFG
206+
## FIXME need method to serialize coordinates for the variable `vnd.val::Vector{P}` and `.bw.Vector{B}`
207+
208+
function _packVNDVal(T::Type, val::Vector{P}) where P
209+
210+
end
211+
212+
213+
# returns a PackedVariableNodeData
214+
function packVariableNodeData(::G, d::VariableNodeData{T}) where {G <: AbstractDFG, T <: InferenceVariable}
206215
@debug "Dispatching conversion variable -> packed variable for type $(string(d.variableType))"
207-
return PackedVariableNodeData(d.val[:],size(d.val,1),
208-
d.bw[:], size(d.bw,1),
216+
precast = getCoordinates.(T, d.val)
217+
@cast castval[i,j] := precast[j][i]
218+
_val = precast[:]
219+
@cast castbw[i,j] := d.bw[j][i]
220+
_bw = castbw[:]
221+
return PackedVariableNodeData(_val, size(castval,1),
222+
_bw, size(castbw,1),
209223
d.BayesNetOutVertIDs,
210224
d.dimIDs, d.dims, d.eliminated,
211225
d.BayesNetVertID, d.separator,
@@ -219,25 +233,28 @@ function packVariableNodeData(dfg::G, d::VariableNodeData)::PackedVariableNodeDa
219233
d.solveKey)
220234
end
221235

222-
function unpackVariableNodeData(dfg::G, d::PackedVariableNodeData)::VariableNodeData where G <: AbstractDFG
223-
r3 = d.dimval
224-
c3 = r3 > 0 ? floor(Int,length(d.vecval)/r3) : 0
225-
M3 = reshape(d.vecval,r3,c3)
226-
227-
r4 = d.dimbw
228-
c4 = r4 > 0 ? floor(Int,length(d.vecbw)/r4) : 0
229-
M4 = reshape(d.vecbw,r4,c4)
230-
231-
@debug "Dispatching conversion packed variable -> variable for type $(string(d.variableType))"
232-
# Figuring out the variableType
233-
# TODO deprecated remove in v0.11 - for backward compatibility for saved variableTypes.
234-
ststring = string(split(d.variableType, "(")[1])
235-
st = getTypeFromSerializationModule(ststring)
236-
isnothing(st) && error("The variable doesn't seem to have a variableType. It needs to set up with an InferenceVariable from IIF. This will happen if you use DFG to add serialized variables directly and try use them. Please use IncrementalInference.addVariable().")
237-
238-
return VariableNodeData{st}(M3,M4, d.BayesNetOutVertIDs,
239-
d.dimIDs, d.dims, d.eliminated, d.BayesNetVertID, d.separator,
240-
st(), d.initialized, d.inferdim, d.ismargin, d.dontmargin, d.solveInProgress, d.solvedCount, d.solveKey)
236+
function unpackVariableNodeData(dfg::G, d::PackedVariableNodeData) where G <: AbstractDFG
237+
@debug "Dispatching conversion packed variable -> variable for type $(string(d.variableType))"
238+
# Figuring out the variableType
239+
# TODO deprecated remove in v0.11 - for backward compatibility for saved variableTypes.
240+
ststring = string(split(d.variableType, "(")[1])
241+
T = getTypeFromSerializationModule(ststring)
242+
isnothing(T) && error("The variable doesn't seem to have a variableType. It needs to set up with an InferenceVariable from IIF. This will happen if you use DFG to add serialized variables directly and try use them. Please use IncrementalInference.addVariable().")
243+
244+
r3 = d.dimval
245+
c3 = r3 > 0 ? floor(Int,length(d.vecval)/r3) : 0
246+
M3 = reshape(d.vecval,r3,c3)
247+
@cast val_[j][i] := M3[i,j]
248+
vals = getPoint.(T, val_)
249+
250+
r4 = d.dimbw
251+
c4 = r4 > 0 ? floor(Int,length(d.vecbw)/r4) : 0
252+
M4 = reshape(d.vecbw,r4,c4)
253+
@cast bw[j][i] := M4[i,j]
254+
255+
return VariableNodeData{T}(vals, bw, d.BayesNetOutVertIDs,
256+
d.dimIDs, d.dims, d.eliminated, d.BayesNetVertID, d.separator,
257+
st(), d.initialized, d.inferdim, d.ismargin, d.dontmargin, d.solveInProgress, d.solvedCount, d.solveKey)
241258
end
242259

243260
##==============================================================================

0 commit comments

Comments
 (0)