Skip to content

Commit 4f8613b

Browse files
authored
Merge pull request #227 from ReactiveBayes/duplicate_args
Bug fix for duplicate arguments
2 parents c9a744a + 5b1ae25 commit 4f8613b

File tree

3 files changed

+143
-12
lines changed

3 files changed

+143
-12
lines changed

src/graph_engine.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,26 @@ Returns a collection of aliases for `fform` depending on the `backend`.
949949
aliases(backend, fform) = error("Backend $backend must implement a method for `aliases` for `$(fform)`.")
950950
aliases(model::Model, fform::F) where {F} = aliases(getbackend(model), fform)
951951

952+
function add_vertex!(model::Model, label, data)
953+
# This is an unsafe procedure that implements behaviour from `MetaGraphsNext`.
954+
code = nv(model) + 1
955+
model.graph.vertex_labels[code] = label
956+
model.graph.vertex_properties[label] = (code, data)
957+
Graphs.add_vertex!(model.graph.graph)
958+
end
959+
960+
function add_edge!(model::Model, src, dst, data)
961+
# This is an unsafe procedure that implements behaviour from `MetaGraphsNext`.
962+
code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst)
963+
model.graph.edge_data[(src, dst)] = data
964+
return Graphs.add_edge!(model.graph.graph, code_src, code_dst)
965+
end
966+
967+
function has_edge(model::Model, src, dst)
968+
code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst)
969+
return Graphs.has_edge(model.graph.graph, code_src, code_dst)
970+
end
971+
952972
"""
953973
copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple)
954974
@@ -1307,10 +1327,8 @@ function add_variable_node!(model::Model, context::Context, options::NodeCreatio
13071327
label, nodedata = preprocess_plugins(
13081328
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
13091329
)
1310-
13111330
context[name, index] = label
1312-
model[label] = nodedata
1313-
1331+
add_vertex!(model, label, nodedata)
13141332
return label
13151333
end
13161334

@@ -1428,7 +1446,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
14281446
UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
14291447
)
14301448

1431-
model[label] = nodedata
1449+
add_vertex!(model, label, nodedata)
14321450
context[factornode_id] = label
14331451

14341452
return label, nodedata, convert(FactorNodeProperties, getproperties(nodedata))
@@ -1500,7 +1518,18 @@ function add_edge!(
15001518
label = EdgeLabel(interface_name, index)
15011519
neighbor_node_label = unroll(variable_node_id)
15021520
addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label])
1503-
model.graph[unroll(variable_node_id), factor_node_id] = label
1521+
edge_added = add_edge!(model, neighbor_node_label, factor_node_id, label)
1522+
if !edge_added
1523+
# Double check if the edge has already been added
1524+
if has_edge(model, neighbor_node_label, factor_node_id)
1525+
error(
1526+
lazy"Trying to create duplicate edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id). Make sure that all the arguments to the `~` operator are unique (both left hand side and right hand side)."
1527+
)
1528+
else
1529+
error(lazy"Cannot create an edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id).")
1530+
end
1531+
end
1532+
return label
15041533
end
15051534

15061535
function add_edge!(

test/graph_construction_tests.jl

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,21 +1500,113 @@ end
15001500
end
15011501
end
15021502

1503-
@testitem "LazyIndex should support empty indices if array is passed" begin
1503+
@testitem "LazyIndex should support empty indices if array is passed" begin
15041504
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex
15051505

15061506
include("testutils.jl")
15071507

1508-
@model function foo(y)
1508+
@model function foo(y)
15091509
x ~ MvNormal([1, 1], [1 0.0; 0.0 1.0])
15101510
y ~ MvNormal(x, [1.0 0.0; 0.0 1.0])
15111511
end
15121512

1513-
model = create_model(foo()) do model, ctx
1514-
return (; y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :y, LazyIndex([ 1.0, 1.0 ])))
1513+
model = create_model(foo()) do model, ctx
1514+
return (; y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :y, LazyIndex([1.0, 1.0])))
15151515
end
15161516

15171517
@test length(collect(filter(as_node(MvNormal), model))) == 2
15181518
@test length(collect(filter(as_variable(:x), model))) == 1
15191519
@test length(collect(filter(as_variable(:y), model))) == 1
1520-
end
1520+
end
1521+
1522+
@testitem "Node arguments must be unique" begin
1523+
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex
1524+
1525+
include("testutils.jl")
1526+
1527+
@model function simple_model_duplicate_1()
1528+
x ~ Normal(0.0, 1.0)
1529+
y ~ x + x
1530+
end
1531+
1532+
@model function simple_model_duplicate_2()
1533+
x ~ Normal(0.0, 1.0)
1534+
y ~ x + x + x
1535+
end
1536+
1537+
@model function simple_model_duplicate_3()
1538+
x ~ Normal(0.0, 1.0)
1539+
y ~ Normal(x, x)
1540+
end
1541+
1542+
@model function simple_model_duplicate_4()
1543+
x ~ Normal(0.0, 1.0)
1544+
hide_x = x
1545+
y ~ Normal(hide_x, x)
1546+
end
1547+
1548+
@model function simple_model_duplicate_5()
1549+
x ~ Normal(0.0, 1.0)
1550+
x ~ Normal(x, 1)
1551+
end
1552+
1553+
@model function simple_model_duplicate_6()
1554+
x ~ Normal(0.0, 1.0)
1555+
hide_x = x
1556+
hide_x ~ Normal(x, 1)
1557+
end
1558+
1559+
for modelfn in [
1560+
simple_model_duplicate_1,
1561+
simple_model_duplicate_2,
1562+
simple_model_duplicate_3,
1563+
simple_model_duplicate_4,
1564+
simple_model_duplicate_5,
1565+
simple_model_duplicate_6
1566+
]
1567+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1568+
modelfn()
1569+
)
1570+
end
1571+
1572+
@model function my_model(obs, N, sigma)
1573+
local x
1574+
for i in 1:N
1575+
x[i] ~ Bernoulli(0.5)
1576+
end
1577+
local C
1578+
# This model creation is not allowed since `C` is used twice in the `~` operator
1579+
for i in 1:N
1580+
C ~ C + x[i]
1581+
end
1582+
obs ~ NormalMeanVariance(C, sigma^2)
1583+
end
1584+
1585+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1586+
my_model(N = 3, sigma = 1.0)
1587+
) do model, ctx
1588+
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
1589+
return (obs = obs,)
1590+
end
1591+
1592+
@model function my_model(obs, N, sigma)
1593+
local x
1594+
for i in 1:N
1595+
x[i] ~ Bernoulli(0.5)
1596+
end
1597+
accum_C = x[1]
1598+
for i in 2:N
1599+
# Here `next_C` will be used twice on the second iteration
1600+
next_C ~ accum_C + x[i]
1601+
accum_C = next_C
1602+
end
1603+
obs ~ NormalMeanVariance(accum_C, sigma^2)
1604+
end
1605+
1606+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1607+
my_model(N = 3, sigma = 1.0)
1608+
) do model, ctx
1609+
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
1610+
return (obs = obs,)
1611+
end
1612+
end

test/graph_engine_tests.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ end
759759
EdgeLabel,
760760
getname,
761761
add_edge!,
762+
has_edge,
762763
getproperties
763764

764765
include("testutils.jl")
@@ -770,12 +771,21 @@ end
770771
b = NodeLabel(:b, 2)
771772
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
772773
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum))
774+
@test !has_edge(model, a, b)
775+
@test !has_edge(model, b, a)
773776
add_edge!(model, b, getproperties(model[b]), a, :edge, 1)
777+
@test has_edge(model, a, b)
778+
@test has_edge(model, b, a)
774779
@test length(edges(model)) == 1
775780

776781
c = NodeLabel(:c, 2)
777782
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum))
783+
@test !has_edge(model, a, c)
784+
@test !has_edge(model, c, a)
778785
add_edge!(model, c, getproperties(model[c]), a, :edge, 2)
786+
@test has_edge(model, a, c)
787+
@test has_edge(model, c, a)
788+
779789
@test length(edges(model)) == 2
780790

781791
# Test 2: Test getting all edges from a model with a specific node
@@ -1842,13 +1852,13 @@ end
18421852
model = create_test_model()
18431853
ctx = getcontext(model)
18441854
options = NodeCreationOptions()
1845-
x, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum)
18461855
y = getorcreate!(model, ctx, :y, nothing)
18471856

18481857
variable_nodes = [getorcreate!(model, ctx, i, nothing) for i in [:a, :b, :c]]
1858+
x, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum)
18491859
add_edge!(model, x, xproperties, variable_nodes, :interface)
18501860

1851-
@test ne(model) == 3 && model[x, variable_nodes[1]] == EdgeLabel(:interface, 1)
1861+
@test ne(model) == 3 && model[variable_nodes[1], x] == EdgeLabel(:interface, 1)
18521862
end
18531863

18541864
@testitem "default_parametrization" begin

0 commit comments

Comments
 (0)