Skip to content

Commit ddd4a99

Browse files
authored
Merge pull request #61 from JuliaRobotics/feature/cloudgraphdfg
CloudGraphs rewritten for DFG is now working
2 parents 4d1fbe2 + 2234397 commit ddd4a99

File tree

7 files changed

+146
-13
lines changed

7 files changed

+146
-13
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1111
JSON2 = "2535ab7d-5cd8-5a07-80ac-9b1792aadce3"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
Neo4j = "d2adbeaf-5838-5367-8a2f-e46d570981db"
1314
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1415
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

src/CloudGraphsDFG/entities/CloudGraphsDFG.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ mutable struct CloudGraphsDFG{T <: AbstractParams} <: AbstractDFG
1414
encodePackedTypeFunc
1515
getPackedTypeFunc
1616
decodePackedTypeFunc
17+
rebuildFactorMetadata!
1718
labelDict::Dict{Symbol, Int64}
1819
variableCache::Dict{Symbol, DFGVariable}
1920
factorCache::Dict{Symbol, DFGFactor}
@@ -23,8 +24,8 @@ mutable struct CloudGraphsDFG{T <: AbstractParams} <: AbstractDFG
2324
end
2425

2526
function show(io::IO, c::CloudGraphsDFG)
26-
println("CloudGraphsDFG:")
27-
println(" - Neo4J instance: $(c.neo4jInstance.connection.host)")
28-
println(" - Session: $(c.userId):$(c.robotId):$(c.sessionId)")
29-
println(" - Caching: $(c.useCache)")
27+
println(io, "CloudGraphsDFG:")
28+
println(io, " - Neo4J instance: $(c.neo4jInstance.connection.host)")
29+
println(io, " - Session: $(c.userId):$(c.robotId):$(c.sessionId)")
30+
println(io, " - Caching: $(c.useCache)")
3031
end

src/CloudGraphsDFG/services/CloudGraphsDFG.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ end
4141
$(SIGNATURES)
4242
Create a new CloudGraphs-based DFG factor graph using a Neo4j.Connection.
4343
"""
44-
function CloudGraphsDFG{T}(neo4jConnection::Neo4j.Connection, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
44+
function CloudGraphsDFG{T}(neo4jConnection::Neo4j.Connection, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, rebuildFactorMetadata!; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
4545
graph = Neo4j.getgraph(neo4jConnection)
4646
neo4jInstance = Neo4jInstance(neo4jConnection, graph)
47-
return CloudGraphsDFG{T}(neo4jInstance, description, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, Dict{Symbol, Int64}(), Dict{Symbol, DFGVariable}(), Dict{Symbol, DFGFactor}(), Symbol[], solverParams, useCache)
47+
return CloudGraphsDFG{T}(neo4jInstance, description, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, rebuildFactorMetadata!, Dict{Symbol, Int64}(), Dict{Symbol, DFGVariable}(), Dict{Symbol, DFGFactor}(), Symbol[], solverParams, useCache)
4848
end
4949
"""
5050
$(SIGNATURES)
5151
Create a new CloudGraphs-based DFG factor graph by specifying the Neo4j connection information.
5252
"""
53-
function CloudGraphsDFG{T}(host::String, port::Int, dbUser::String, dbPassword::String, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
53+
function CloudGraphsDFG{T}(host::String, port::Int, dbUser::String, dbPassword::String, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, rebuildFactorMetadata!; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
5454
neo4jConnection = Neo4j.Connection(host, port=port, user=dbUser, password=dbPassword);
55-
return CloudGraphsDFG{T}(neo4jConnection, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, description=description, solverParams=solverParams, useCache=useCache)
55+
return CloudGraphsDFG{T}(neo4jConnection, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, rebuildFactorMetadata!, description=description, solverParams=solverParams, useCache=useCache)
5656
end
5757

5858
"""
@@ -323,10 +323,10 @@ function getFactor(dfg::CloudGraphsDFG, factorId::Int64)::DFGFactor
323323

324324
data = props["data"]
325325
datatype = props["fnctype"]
326-
fulltype = getfield(Main, Symbol(datatype))
326+
# fulltype = getfield(Main, Symbol(datatype))
327327
packtype = getfield(Main, Symbol("Packed"*datatype))
328328
packed = JSON2.read(data, GenericFunctionNodeData{packtype,String})
329-
fullFactor = dfg.decodePackedTypeFunc(packed, "")
329+
fullFactor = dfg.decodePackedTypeFunc(dfg, packed)
330330

331331
# Include the type
332332
_variableOrderSymbols = JSON2.read(props["_variableOrderSymbols"], Vector{Symbol})
@@ -341,6 +341,12 @@ function getFactor(dfg::CloudGraphsDFG, factorId::Int64)::DFGFactor
341341
factor.ready = ready
342342
factor.backendset = backendset
343343

344+
# Lastly, rebuild the metadata
345+
factor = dfg.rebuildFactorMetadata!(dfg, factor)
346+
# GUARANTEED never to bite us in the ass in the future...
347+
# ... TODO: refactor if changed: https://github.com/JuliaRobotics/IncrementalInference.jl/issues/350
348+
getData(factor).fncargvID = _variableOrderSymbols
349+
344350
# Add to cache
345351
push!(dfg.factorCache, factor.label=>factor)
346352

src/DistributedFactorGraphs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Dates
77
using Distributions
88
using Reexport
99
using JSON2
10+
using LinearAlgebra
1011

1112
# Entities
1213
include("entities/AbstractTypes.jl")
@@ -33,6 +34,7 @@ export GenericFunctionNodeData#, FunctionNodeData
3334
export getSerializationModule, setSerializationModule!
3435
export pack, unpack
3536

37+
# Common includes
3638
include("services/AbstractDFG.jl")
3739
include("services/DFGVariable.jl")
3840

src/GraphsDFG/services/GraphsDFG.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ function addFactor!(dfg::GraphsDFG, variables::Vector{DFGVariable}, factor::DFGF
8989
return true
9090
end
9191

92+
"""
93+
$(SIGNATURES)
94+
Add a DFGFactor to a DFG.
95+
"""
96+
function addFactor!(dfg::GraphsDFG, variableIds::Vector{Symbol}, factor::DFGFactor)::Bool
97+
variables = map(vId -> getVariable(dfg, vId), variableIds)
98+
return addFactor!(dfg, variables, factor)
99+
end
100+
92101
"""
93102
$(SIGNATURES)
94103
Get a DFGVariable from a DFG using its underlying integer ID.

src/entities/DFGFactor.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ mutable struct DFGFactor{T, S} <: DFGNode
4747
DFGFactor{T, S}(label::Symbol, _internalId::Int64) where {T, S} = new{T, S}(label, Symbol[], GenericFunctionNodeData{T, S}(), 0, 0, _internalId, Symbol[])
4848
end
4949

50-
# const FunctionNodeData{T} = GenericFunctionNodeData{T, Symbol}
51-
# FunctionNodeData(x1, x2, x3, x4, x5::Symbol, x6::T, x7::String="", x8::Vector{Int}=Int[]) where {T <: Union{FunctorInferenceType, ConvolutionObject}}= GenericFunctionNodeData{T, Symbol}(x1, x2, x3, x4, x5, x6, x7, x8)
52-
5350
label(f::F) where F <: DFGFactor = f.label
5451
data(f::F) where F <: DFGFactor = f.data
5552
id(f::F) where F <: DFGFactor = f._internalId

test/HexagonalCloud.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
using Revise
2+
using Neo4j # So that DFG initializes the database driver.
3+
using RoME
4+
using DistributedFactorGraphs
5+
using Test
6+
7+
# start with an empty factor graph object
8+
# fg = initfg()
9+
cloudFg = CloudGraphsDFG{SolverParams}("localhost", 7474, "neo4j", "test",
10+
"testUser", "testRobot", "testSession",
11+
nothing,
12+
nothing,
13+
IncrementalInference.decodePackedType,
14+
IncrementalInference.rebuildFactorMetadata!,
15+
solverParams=SolverParams())
16+
# cloudFg = GraphsDFG{SolverParams}(params=SolverParams())
17+
# cloudFg = GraphsDFG{SolverParams}(params=SolverParams())
18+
clearSession!!(cloudFg)
19+
# cloudFg = initfg()
20+
21+
# Add the first pose :x0
22+
x0 = addVariable!(cloudFg, :x0, Pose2)
23+
IncrementalInference.compareVariable(x0, getVariable(cloudFg, :x0))
24+
25+
# Add at a fixed location PriorPose2 to pin :x0 to a starting location (10,10, pi/4)
26+
prior = addFactor!(cloudFg, [:x0], PriorPose2( MvNormal([10; 10; 1.0/8.0], Matrix(Diagonal([0.1;0.1;0.05].^2))) ) )
27+
# retPrior = getFactor(cloudFg, :x0f1)
28+
# Do the check
29+
# IncrementalInference.compareFactor(prior, retPrior)
30+
# Testing
31+
32+
# retPrior.data.fnc.cpt = prior.data.fnc.cpt
33+
# # This one
34+
# prior.data.fnc.cpt[1].factormetadata
35+
# deserialized: Any[Pose2(3, String[], (:Euclid, :Euclid, :Circular))]
36+
# vs.
37+
# original: Pose2[Pose2(3, String[], (:Euclid, :Euclid, :Circular))]
38+
39+
# Drive around in a hexagon in the cloud
40+
for i in 0:5
41+
psym = Symbol("x$i")
42+
nsym = Symbol("x$(i+1)")
43+
addVariable!(cloudFg, nsym, Pose2)
44+
pp = Pose2Pose2(MvNormal([10.0;0;pi/3], Matrix(Diagonal([0.1;0.1;0.1].^2))))
45+
addFactor!(cloudFg, [psym;nsym], pp )
46+
end
47+
48+
# Right, let's copy it into local memory for solving...
49+
localFg = GraphsDFG{SolverParams}(params=SolverParams())
50+
DistributedFactorGraphs._copyIntoGraph!(cloudFg, localFg, union(getVariableIds(cloudFg), getFactorIds(cloudFg)), true)
51+
# Some checks
52+
@test symdiff(getVariableIds(localFg), getVariableIds(cloudFg)) == []
53+
@test symdiff(getFactorIds(localFg), getFactorIds(cloudFg)) == []
54+
@test isFullyConnected(localFg)
55+
# Show it
56+
toDotFile(localFg, "/tmp/localfg.dot")
57+
58+
# Alrighty! At this point, we should be able to solve locally...
59+
# perform inference, and remember first runs are slower owing to Julia's just-in-time compiling
60+
# Can do with graph too!
61+
tree, smt, hist = solveTree!(localFg)
62+
63+
wipeBuildNewTree!(localFg)
64+
tree, smt, hist = solveTree!(localFg, tree) # Recycle
65+
# batchSolve!(localFg, drawpdf=true, show=true)
66+
# Erm, whut? Error = mcmcIterationIDs -- unaccounted variables
67+
68+
# Trying new method.
69+
tree, smtasks = batchSolve!(localFg, treeinit=true, drawpdf=true, show=true,
70+
returntasks=true, limititers=50,
71+
upsolve=true, downsolve=true )
72+
73+
#### WIP and general debugging
74+
75+
# Testing with GenericMarginal
76+
# This will not work because GenericMarginal *shouldn't* really be persisted.
77+
# That would mean we're decomposing the cloud graph...
78+
# genmarg = GenericMarginal()
79+
# Xi = [getVariable(fg, :x0)]
80+
# addFactor!(fg, Xi, genmarg, autoinit=false)
81+
82+
# For Juno/Jupyter style use
83+
pl = drawPoses(localFg, meanmax=:mean)
84+
plotPose(fg, :x6)
85+
# For scripting use-cases you can export the image
86+
Gadfly.draw(Gadfly.PDF("/tmp/test1.pdf", 20cm, 10cm),pl) # or PNG(...)
87+
88+
89+
# Add landmarks with Bearing range measurements
90+
addVariable!(fg, :l1, Point2, labels=["LANDMARK"])
91+
p2br = Pose2Point2BearingRange(Normal(0,0.1),Normal(20.0,1.0))
92+
addFactor!(fg, [:x0; :l1], p2br )
93+
94+
95+
# Initialize :l1 numerical values but do not rerun solver
96+
ensureAllInitialized!(fg)
97+
pl = drawPosesLandms(fg)
98+
Gadfly.draw(Gadfly.PDF("/tmp/test2.pdf", 20cm, 10cm),pl) # or PNG(...)
99+
100+
101+
# Add landmarks with Bearing range measurements
102+
p2br2 = Pose2Point2BearingRange(Normal(0,0.1),Normal(20.0,1.0))
103+
addFactor!(fg, [:x6; :l1], p2br2 )
104+
105+
106+
# solve
107+
batchSolve!(fg, drawpdf=true)
108+
109+
110+
# redraw
111+
pl = drawPosesLandms(fg, meanmax=:mean)
112+
Gadfly.draw(Gadfly.PDF("/tmp/test3.pdf", 20cm, 10cm),pl) # or PNG(...)
113+
114+
115+
116+
117+
#

0 commit comments

Comments
 (0)