Skip to content

Commit 9775fe1

Browse files
committed
Add id to NodeData
1 parent 85e8a31 commit 9775fe1

File tree

3 files changed

+35
-30
lines changed

3 files changed

+35
-30
lines changed

src/graph_engine.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ Base.broadcastable(label::NodeLabel) = Ref(label)
255255

256256
getname(label::NodeLabel) = label.name
257257
getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels))
258+
getid(label::NodeLabel) = label.global_counter
258259
iterate(label::NodeLabel) = (label, nothing)
259260
iterate(label::NodeLabel, any) = nothing
260261

@@ -765,9 +766,10 @@ mutable struct NodeData
765766
const context :: Context
766767
const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}}
767768
const extra :: UnorderedDictionary{Symbol, Any}
769+
const id :: Int
768770
end
769771

770-
NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}())
772+
NodeData(context, properties, id) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}(), id)
771773

772774
function Base.show(io::IO, nodedata::NodeData)
773775
context = getcontext(nodedata)
@@ -1529,7 +1531,7 @@ end
15291531
function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
15301532
# In theory plugins are able to overwrite this
15311533
potential_label = generate_nodelabel(model, name)
1532-
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options))
1534+
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options), getid(potential_label))
15331535
label, nodedata = preprocess_plugins(
15341536
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
15351537
)
@@ -1643,7 +1645,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
16431645
factornode_id = generate_factor_nodelabel(context, fform)
16441646

16451647
potential_label = generate_nodelabel(model, fform)
1646-
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options))
1648+
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options), getid(potential_label))
16471649

16481650
label, nodedata = preprocess_plugins(
16491651
UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options

test/graph_engine_tests.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118

119119
@testset "FactorNodeProperties" begin
120120
properties = FactorNodeProperties(fform = String)
121-
nodedata = NodeData(context, properties)
121+
nodedata = NodeData(context, properties, 1)
122122

123123
@test getcontext(nodedata) === context
124124
@test getproperties(nodedata) === properties
@@ -135,7 +135,7 @@ end
135135

136136
@testset "VariableNodeProperties" begin
137137
properties = VariableNodeProperties(name = :x, index = 1)
138-
nodedata = NodeData(context, properties)
138+
nodedata = NodeData(context, properties, 1)
139139

140140
@test getcontext(nodedata) === context
141141
@test getproperties(nodedata) === properties
@@ -183,7 +183,7 @@ end
183183
context = getcontext(model)
184184

185185
@testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1))
186-
nodedata = NodeData(context, properties)
186+
nodedata = NodeData(context, properties, 1)
187187

188188
@test !hasextra(nodedata, :a)
189189
@test getextra(nodedata, :a, 2) === 2
@@ -552,7 +552,10 @@ end
552552

553553
function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options)
554554
# Here we replace the original options entirely
555-
return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0)))
555+
return label,
556+
NodeData(
557+
context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0)), GraphPPL.getid(label)
558+
)
556559
end
557560

558561
for model_fn in ModelsInTheZooWithoutArguments
@@ -933,13 +936,13 @@ end
933936

934937
model = create_test_model()
935938
ctx = getcontext(model)
936-
model[NodeLabel(, 1)] = NodeData(ctx, VariableNodeProperties(name = , index = nothing))
939+
model[NodeLabel(, 1)] = NodeData(ctx, VariableNodeProperties(name = , index = nothing), 1)
937940
@test nv(model) == 1 && ne(model) == 0
938941

939-
model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
942+
model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2)
940943
@test nv(model) == 2 && ne(model) == 0
941944

942-
model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum))
945+
model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum), 3)
943946
@test nv(model) == 3 && ne(model) == 0
944947

945948
@test_throws MethodError model[0] = 1
@@ -959,8 +962,8 @@ end
959962
μ = NodeLabel(, 1)
960963
xref = NodeLabel(:x, 2)
961964

962-
model[μ] = NodeData(ctx, VariableNodeProperties(name = , index = nothing))
963-
model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
965+
model[μ] = NodeData(ctx, VariableNodeProperties(name = , index = nothing), 1)
966+
model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2)
964967
model[μ, xref] = EdgeLabel(:interface, 1)
965968

966969
@test ne(model) == 1
@@ -990,7 +993,7 @@ end
990993
model = create_test_model()
991994
ctx = getcontext(model)
992995
label = NodeLabel(:x, 1)
993-
model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
996+
model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 1)
994997
@test isa(model[label], NodeData)
995998
@test isa(getproperties(model[label]), VariableNodeProperties)
996999
@test_throws KeyError model[NodeLabel(:x, 10)]
@@ -1024,8 +1027,8 @@ end
10241027
@test nv(model) == 0
10251028
@test ne(model) == 0
10261029

1027-
model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
1028-
model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing))
1030+
model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1)
1031+
model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2)
10291032
@test !isempty(model)
10301033
@test nv(model) == 2
10311034
@test ne(model) == 0
@@ -1059,8 +1062,8 @@ end
10591062
ctx = getcontext(model)
10601063
a = NodeLabel(:a, 1)
10611064
b = NodeLabel(:b, 2)
1062-
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
1063-
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum))
1065+
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1)
1066+
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum), 2)
10641067
@test !has_edge(model, a, b)
10651068
@test !has_edge(model, b, a)
10661069
add_edge!(model, b, getproperties(model[b]), a, :edge, 1)
@@ -1069,7 +1072,7 @@ end
10691072
@test length(edges(model)) == 1
10701073

10711074
c = NodeLabel(:c, 2)
1072-
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum))
1075+
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum), 2)
10731076
@test !has_edge(model, a, c)
10741077
@test !has_edge(model, c, a)
10751078
add_edge!(model, c, getproperties(model[c]), a, :edge, 2)
@@ -1109,8 +1112,8 @@ end
11091112

11101113
a = NodeLabel(:a, 1)
11111114
b = NodeLabel(:b, 2)
1112-
model[a] = NodeData(ctx, FactorNodeProperties(fform = sum))
1113-
model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing))
1115+
model[a] = NodeData(ctx, FactorNodeProperties(fform = sum), 1)
1116+
model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2)
11141117
add_edge!(model, a, getproperties(model[a]), b, :edge, 1)
11151118
@test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)]
11161119

@@ -1120,9 +1123,9 @@ end
11201123
b = ResizableArray(NodeLabel, Val(1))
11211124
for i in 1:3
11221125
a[i] = NodeLabel(:a, i)
1123-
model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum))
1126+
model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum), i)
11241127
b[i] = NodeLabel(:b, i)
1125-
model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i))
1128+
model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i), i)
11261129
add_edge!(model, a[i], getproperties(model[a[i]]), b[i], :edge, i)
11271130
end
11281131
for n in b

test/plugins/variational_constraints/variational_constraints_engine_tests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -862,35 +862,35 @@ end
862862
])
863863

864864
variable = ResolvedIndexedVariable(:w, 2:3, context)
865-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
865+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
866866
@test node_data variable
867867

868868
variable = ResolvedIndexedVariable(:w, 2:3, context)
869-
node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2))
869+
node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2), 2)
870870
@test !(node_data variable)
871871

872872
variable = ResolvedIndexedVariable(:w, 2, context)
873-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
873+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
874874
@test node_data variable
875875

876876
variable = ResolvedIndexedVariable(:w, SplittedRange(2, 3), context)
877-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
877+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
878878
@test node_data variable
879879

880880
variable = ResolvedIndexedVariable(:w, SplittedRange(10, 15), context)
881-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
881+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
882882
@test !(node_data variable)
883883

884884
variable = ResolvedIndexedVariable(:x, nothing, context)
885-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2))
885+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2), 2)
886886
@test node_data variable
887887

888888
variable = ResolvedIndexedVariable(:x, nothing, context)
889-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing))
889+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing), 1)
890890
@test node_data variable
891891

892892
variable = ResolvedIndexedVariable(:prec, 3, context)
893-
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3)))
893+
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3)), 2)
894894
@test node_data variable
895895
end
896896

0 commit comments

Comments
 (0)