Skip to content

Commit 3a2ecea

Browse files
committed
Refactoring for solver parameters, cleanup
1 parent b7f496f commit 3a2ecea

File tree

14 files changed

+116
-80
lines changed

14 files changed

+116
-80
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ language: julia
33
os:
44
- linux
55

6+
services:
7+
- neo4j
8+
69
julia:
710
- 1.0
811
- 1.1

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ julia = "0.7, 1"
2323

2424
[extras]
2525
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
26+
IncrementalInference = "904591bb-b899-562f-9e6f-b8df64c7d480"
27+
RoME = "91fb55c2-4c03-5a59-ba21-f4ea956187b8"
2628

2729
[targets]
2830
test = ["Test"]

src/CloudGraphsDFG/entities/CloudGraphsDFG.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mutable struct Neo4jInstance
55
graph::Neo4j.Graph
66
end
77

8-
mutable struct CloudGraphsDFG <: AbstractDFG
8+
mutable struct CloudGraphsDFG{T <: AbstractParams} <: AbstractDFG
99
neo4jInstance::Neo4jInstance
1010
description::String
1111
userId::String
@@ -18,7 +18,7 @@ mutable struct CloudGraphsDFG <: AbstractDFG
1818
variableCache::Dict{Symbol, DFGVariable}
1919
factorCache::Dict{Symbol, DFGFactor}
2020
addHistory::Vector{Symbol} #TODO: Discuss more - is this an audit trail?
21-
solverParams::Any # Solver parameters
21+
solverParams::T # Solver parameters
2222
useCache::Bool
2323
end
2424

src/CloudGraphsDFG/services/CloudGraphsDFG.jl

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ end
4141
$(SIGNATURES)
4242
Create a new CloudGraphs-based DFG factor graph using a Neo4j.Connection.
4343
"""
44-
function CloudGraphsDFG(neo4jConnection::Neo4j.Connection, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::Any=nothing, useCache::Bool=false)
44+
function CloudGraphsDFG{T}(neo4jConnection::Neo4j.Connection, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
4545
graph = Neo4j.getgraph(neo4jConnection)
4646
neo4jInstance = Neo4jInstance(neo4jConnection, graph)
47-
return CloudGraphsDFG(neo4jInstance, description, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, Dict{Symbol, Int64}(), Dict{Symbol, DFGVariable}(), Dict{Symbol, DFGFactor}(), Symbol[], solverParams, useCache)
47+
return CloudGraphsDFG{T}(neo4jInstance, description, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, Dict{Symbol, Int64}(), Dict{Symbol, DFGVariable}(), Dict{Symbol, DFGFactor}(), Symbol[], solverParams, useCache)
4848
end
4949
"""
5050
$(SIGNATURES)
5151
Create a new CloudGraphs-based DFG factor graph by specifying the Neo4j connection information.
5252
"""
53-
function CloudGraphsDFG(host::String, port::Int, dbUser::String, dbPassword::String, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::Any=nothing, useCache::Bool=false)
53+
function CloudGraphsDFG{T}(host::String, port::Int, dbUser::String, dbPassword::String, userId::String, robotId::String, sessionId::String, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc; description::String="CloudGraphs DFG", solverParams::T=NoSolverParams(), useCache::Bool=false) where T <: AbstractParams
5454
neo4jConnection = Neo4j.Connection(host, port=port, user=dbUser, password=dbPassword);
55-
return CloudGraphsDFG(neo4jConnection, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, description=description, solverParams=solverParams, useCache=useCache)
55+
return CloudGraphsDFG{T}(neo4jConnection, userId, robotId, sessionId, encodePackedTypeFunc, getPackedTypeFunc, decodePackedTypeFunc, description=description, solverParams=solverParams, useCache=useCache)
5656
end
5757

5858
"""
@@ -68,7 +68,7 @@ function _getDuplicatedEmptyDFG(dfg::CloudGraphsDFG)::CloudGraphsDFG
6868
length(_getLabelsFromCyphonQuery(dfg.neo4jInstance, "(node:$(dfg.userId):$(dfg.robotId):$(sessionId))")) == 0 && break
6969
end
7070
@debug "Unique+empty copy session name: $sessionId"
71-
return CloudGraphsDFG(dfg.neo4jInstance.connection, dfg.userId, dfg.robotId, sessionId, dfg.encodePackedTypeFunc, dfg.getPackedTypeFunc, dfg.decodePackedTypeFunc, solverParams=deepcopy(dfg.solverParams), description="(Copy of) $(dfg.description)", useCache=dfg.useCache)
71+
return CloudGraphsDFG{typeof(dfg.solverParams)}(dfg.neo4jInstance.connection, dfg.userId, dfg.robotId, sessionId, dfg.encodePackedTypeFunc, dfg.getPackedTypeFunc, dfg.decodePackedTypeFunc, solverParams=deepcopy(dfg.solverParams), description="(Copy of) $(dfg.description)", useCache=dfg.useCache)
7272
end
7373

7474
# Accessors
@@ -77,7 +77,9 @@ getDescription(dfg::CloudGraphsDFG) = dfg.description
7777
setDescription(dfg::CloudGraphsDFG, description::String) = dfg.description = description
7878
getAddHistory(dfg::CloudGraphsDFG) = dfg.addHistory
7979
getSolverParams(dfg::CloudGraphsDFG) = dfg.solverParams
80-
setSolverParams(dfg::CloudGraphsDFG, solverParams::Any) = dfg.solverParams = solverParams
80+
function setSolverParams(dfg::CloudGraphsDFG, solverParams::T) where T <: AbstractParams
81+
dfg.solverParams = solverParams
82+
end
8183

8284
"""
8385
$(SIGNATURES)
@@ -440,16 +442,12 @@ function updateFactor!(dfg::CloudGraphsDFG, variables::Vector{DFGVariable}, fact
440442
# Update the body
441443
factor = updateFactor!(dfg, factor)
442444

443-
@show "HERE WITH $(factor.label)!"
444-
@show map(v->v.label, variables)
445-
446445
# Now update the relationships
447-
@show existingNeighbors = getNeighbors(dfg, factor)
446+
existingNeighbors = getNeighbors(dfg, factor)
448447
if symdiff(existingNeighbors, map(v->v.label, variables)) == []
449448
# Done, otherwise we need to remake the edges.
450449
return factor
451450
end
452-
@show "HERE WITH $(factor.label)!"
453451
# Delete existing relationships
454452
fNode = Neo4j.getnode(dfg.neo4jInstance.graph, factor._internalId)
455453
for relationship in Neo4j.getrels(fNode)
@@ -707,39 +705,8 @@ function ls(dfg::CloudGraphsDFG, label::Symbol)::Vector{Symbol} where T <: DFGNo
707705
return getNeighbors(dfg, label)
708706
end
709707

710-
function _copyIntoGraph!(sourceDFG::CloudGraphsDFG, destDFG::CloudGraphsDFG, variableFactorLabels::Vector{Symbol}, includeOrphanFactors::Bool=false)::Nothing
711-
@show variableFactorLabels
712-
# Split into variables and factors
713-
sourceVariables = map(vId->getVariable(sourceDFG, vId), intersect(getVariableIds(sourceDFG), variableFactorLabels))
714-
sourceFactors = map(fId->getFactor(sourceDFG, fId), intersect(getFactorIds(sourceDFG), variableFactorLabels))
715-
if length(sourceVariables) + length(sourceFactors) != length(variableFactorLabels)
716-
rem = symdiff(map(v->v.label, sourceVariables), variableFactorLabels)
717-
rem = symdiff(map(f->f.label, sourceFactors), variableFactorLabels)
718-
error("Cannot copy because cannot find the following nodes in the source graph: $rem")
719-
end
720-
721-
# Now we have to add all variables first,
722-
for variable in sourceVariables
723-
addVariable!(destDFG, deepcopy(variable))
724-
end
725-
# And then all factors to the destDFG.
726-
for factor in sourceFactors
727-
# Get the original factor variables (we need them to create it)
728-
sourceFactorVariableIds = getNeighbors(sourceDFG, factor)
729-
# Find the labels and associated variables in our new subgraph
730-
factVariableIds = Symbol[]
731-
for variable in sourceFactorVariableIds
732-
if exists(destDFG, variable)
733-
push!(factVariableIds, variable)
734-
end
735-
end
736-
# Only if we have all of them should we add it (otherwise strange things may happen on evaluation)
737-
if includeOrphanFactors || length(factVariableIds) == length(sourceFactorVariableIds)
738-
addFactor!(destDFG, factVariableIds, deepcopy(factor))
739-
end
740-
end
741-
return nothing
742-
end
708+
## This is moved to services/AbstractDFG.jl
709+
# function _copyIntoGraph!(sourceDFG::CloudGraphsDFG, destDFG::CloudGraphsDFG, variableFactorLabels::Vector{Symbol}, includeOrphanFactors::Bool=false)::Nothing
743710

744711
"""
745712
$(SIGNATURES)
@@ -809,8 +776,8 @@ function getAdjacencyMatrix(dfg::CloudGraphsDFG)::Matrix{Union{Nothing, Symbol}}
809776
error(string(nodes.errors))
810777
end
811778
# Add in the relationships
812-
@show varRel = Symbol.(map(node -> node["row"][1], nodes.results[1]["data"]))
813-
@show factRel = Symbol.(map(node -> node["row"][2], nodes.results[1]["data"]))
779+
varRel = Symbol.(map(node -> node["row"][1], nodes.results[1]["data"]))
780+
factRel = Symbol.(map(node -> node["row"][2], nodes.results[1]["data"]))
814781
for i = 1:length(varRel)
815782
adjMat[fDict[factRel[i]], vDict[varRel[i]]] = factRel[i]
816783
end

src/GraphsDFG/entities/GraphsDFG.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@ mutable struct GraphsNode
88
end
99
const FGType = Graphs.GenericIncidenceList{GraphsNode,Graphs.Edge{GraphsNode},Dict{Int,GraphsNode},Dict{Int,Array{Graphs.Edge{GraphsNode},1}}}
1010

11-
abstract type AbstractParams end
12-
13-
mutable struct NoSolverParams <: AbstractParams
14-
end
15-
16-
mutable struct GraphsDFG{T<:AbstractParams} <: AbstractDFG
11+
mutable struct GraphsDFG{T <: AbstractParams} <: AbstractDFG
1712
g::FGType
1813
description::String
1914
nodeCounter::Int64

src/entities/AbstractTypes.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@ Abstract parent struct for a DFG graph.
1212
"""
1313
abstract type AbstractDFG
1414
end
15+
16+
abstract type AbstractParams end
17+
18+
mutable struct NoSolverParams <: AbstractParams
19+
end

src/entities/DFGVariable.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ mutable struct VariableNodeData
1515
partialinit::Bool
1616
ismargin::Bool
1717
dontmargin::Bool
18-
VariableNodeData() = new()
18+
# A valid, packable default constructor is needed.
19+
VariableNodeData() = new(zeros(1,1), zeros(1,1), Symbol[], Int[], 0, false, :NOTHING, Symbol[], "", false, false, false, false)
1920
VariableNodeData(x1::Array{Float64,2},
2021
x2::Array{Float64,2},
2122
x3::Vector{Symbol},

src/services/AbstractDFG.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,41 @@ function getSerializationModule(dfg::G)::Module where G <: AbstractDFG
2424
@warn "Retrieving serialization module from AbstractDFG - override this in the '$(typeof(dfg)) structure! This is returning Main"
2525
return Main
2626
end
27+
28+
"""
29+
$(SIGNATURES)
30+
Common function for copying nodes from one graph into another graph.
31+
This is overridden in specialized implementations for performance.
32+
"""
33+
function _copyIntoGraph!(sourceDFG::G, destDFG::H, variableFactorLabels::Vector{Symbol}, includeOrphanFactors::Bool=false)::Nothing where {G <: AbstractDFG, H <: AbstractDFG}
34+
# Split into variables and factors
35+
sourceVariables = map(vId->getVariable(sourceDFG, vId), intersect(getVariableIds(sourceDFG), variableFactorLabels))
36+
sourceFactors = map(fId->getFactor(sourceDFG, fId), intersect(getFactorIds(sourceDFG), variableFactorLabels))
37+
if length(sourceVariables) + length(sourceFactors) != length(variableFactorLabels)
38+
rem = symdiff(map(v->v.label, sourceVariables), variableFactorLabels)
39+
rem = symdiff(map(f->f.label, sourceFactors), variableFactorLabels)
40+
error("Cannot copy because cannot find the following nodes in the source graph: $rem")
41+
end
42+
43+
# Now we have to add all variables first,
44+
for variable in sourceVariables
45+
addVariable!(destDFG, deepcopy(variable))
46+
end
47+
# And then all factors to the destDFG.
48+
for factor in sourceFactors
49+
# Get the original factor variables (we need them to create it)
50+
sourceFactorVariableIds = getNeighbors(sourceDFG, factor)
51+
# Find the labels and associated variables in our new subgraph
52+
factVariableIds = Symbol[]
53+
for variable in sourceFactorVariableIds
54+
if exists(destDFG, variable)
55+
push!(factVariableIds, variable)
56+
end
57+
end
58+
# Only if we have all of them should we add it (otherwise strange things may happen on evaluation)
59+
if includeOrphanFactors || length(factVariableIds) == length(sourceFactorVariableIds)
60+
addFactor!(destDFG, factVariableIds, deepcopy(factor))
61+
end
62+
end
63+
return nothing
64+
end

src/services/DFGVariable.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function pack(dfg::G, d::VariableNodeData)::PackedVariableNodeData where G <: Ab
88
d.BayesNetOutVertIDs,
99
d.dimIDs, d.dims, d.eliminated,
1010
d.BayesNetVertID, d.separator,
11-
string(d.softtype), d.initialized, d.partialinit, d.ismargin, d.dontmargin)
11+
d.softtype != nothing ? string(d.softtype) : nothing, d.initialized, d.partialinit, d.ismargin, d.dontmargin)
1212
end
1313

1414
function unpack(dfg::G, d::PackedVariableNodeData)::VariableNodeData where G <: AbstractDFG
@@ -26,10 +26,12 @@ function unpack(dfg::G, d::PackedVariableNodeData)::VariableNodeData where G <:
2626
mainmod = getSerializationModule(dfg)
2727
mainmod == nothing && error("Serialization module is null - please call setSerializationNamespace!(\"Main\" => Main) in your main program.")
2828
try
29-
unpackedTypeName = split(d.softtype, "(")[1]
30-
unpackedTypeName = split(unpackedTypeName, '.')[end]
31-
@debug "DECODING Softtype = $unpackedTypeName"
32-
st = getfield(mainmod, Symbol(unpackedTypeName))()
29+
if d.softtype != ""
30+
unpackedTypeName = split(d.softtype, "(")[1]
31+
unpackedTypeName = split(unpackedTypeName, '.')[end]
32+
@debug "DECODING Softtype = $unpackedTypeName"
33+
st = getfield(mainmod, Symbol(unpackedTypeName))()
34+
end
3335
catch ex
3436
@error "Unable to deserialize soft type $(d.softtype)"
3537
io = IOBuffer()

src/services/LightDFGraph.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

0 commit comments

Comments
 (0)