Skip to content

Commit 0701b46

Browse files
committed
more tests to the god of tests
1 parent 44af1c0 commit 0701b46

File tree

3 files changed

+80
-17
lines changed

3 files changed

+80
-17
lines changed

src/graph_engine.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,11 @@ function add_edge!(model::Model, src, dst, data)
962962
return Graphs.add_edge!(model.graph.graph, code_src, code_dst)
963963
end
964964

965+
function has_edge(model::Model, src, dst)
966+
code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst)
967+
return Graphs.has_edge(model.graph.graph, code_src, code_dst)
968+
end
969+
965970
"""
966971
copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple)
967972
@@ -1511,11 +1516,16 @@ function add_edge!(
15111516
label = EdgeLabel(interface_name, index)
15121517
neighbor_node_label = unroll(variable_node_id)
15131518
addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label])
1514-
edge_added = add_edge!(model, unroll(variable_node_id), factor_node_id, label)
1519+
edge_added = add_edge!(model, neighbor_node_label, factor_node_id, label)
15151520
if !edge_added
1516-
error(
1517-
lazy"Trying to create duplicate edge ($(unroll(variable_node_id)), $(factor_node_id)) while creating edge $(label) of factor node with functional form $(factor_node_propeties.fform)"
1518-
)
1521+
# Double check if the edge has already been added
1522+
if has_edge(model, neighbor_node_label, factor_node_id)
1523+
error(
1524+
lazy"Trying to create duplicate edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id). Make sure that all the arguments to the `~` operator are unique (both left hand side and right hand side)."
1525+
)
1526+
else
1527+
error(lazy"Cannot create an edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id).")
1528+
end
15191529
end
15201530
return label
15211531
end

test/graph_construction_tests.jl

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,46 +1524,89 @@ end
15241524

15251525
include("testutils.jl")
15261526

1527+
@model function simple_model_duplicate_1()
1528+
x ~ Normal(0.0, 1.0)
1529+
y ~ x + x
1530+
end
1531+
1532+
@model function simple_model_duplicate_2()
1533+
x ~ Normal(0.0, 1.0)
1534+
y ~ x + x + x
1535+
end
1536+
1537+
@model function simple_model_duplicate_3()
1538+
x ~ Normal(0.0, 1.0)
1539+
y ~ Normal(x, x)
1540+
end
1541+
1542+
@model function simple_model_duplicate_4()
1543+
x ~ Normal(0.0, 1.0)
1544+
hide_x = x
1545+
y ~ Normal(hide_x, x)
1546+
end
1547+
1548+
@model function simple_model_duplicate_5()
1549+
x ~ Normal(0.0, 1.0)
1550+
x ~ Normal(x, 1)
1551+
end
1552+
1553+
@model function simple_model_duplicate_6()
1554+
x ~ Normal(0.0, 1.0)
1555+
hide_x = x
1556+
hide_x ~ Normal(x, 1)
1557+
end
1558+
1559+
for modelfn in [
1560+
simple_model_duplicate_1,
1561+
simple_model_duplicate_2,
1562+
simple_model_duplicate_3,
1563+
simple_model_duplicate_4,
1564+
simple_model_duplicate_5,
1565+
simple_model_duplicate_6
1566+
]
1567+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1568+
modelfn()
1569+
)
1570+
end
1571+
15271572
@model function my_model(obs, N, sigma)
1528-
local p
1529-
for i in 1:N
1530-
p[i] ~ Beta(1, 1)
1531-
end
15321573
local x
15331574
for i in 1:N
1534-
x[i] ~ Bernoulli(p[i])
1575+
x[i] ~ Bernoulli(0.5)
15351576
end
15361577
local C
1578+
# This model creation is not allowed since `C` is used twice in the `~` operator
15371579
for i in 1:N
15381580
C ~ C + x[i]
15391581
end
15401582
obs ~ NormalMeanVariance(C, sigma^2)
15411583
end
15421584

1543-
@test_throws "Trying to create duplicate edge" create_model(my_model(N = 3, sigma = 1.0)) do model, ctx
1585+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1586+
my_model(N = 3, sigma = 1.0)
1587+
) do model, ctx
15441588
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
15451589
return (obs = obs,)
15461590
end
15471591

15481592
@model function my_model(obs, N, sigma)
1549-
local p
1550-
for i in 1:N
1551-
p[i] ~ Beta(1, 1)
1552-
end
15531593
local x
15541594
for i in 1:N
1555-
x[i] ~ Bernoulli(p[i])
1595+
x[i] ~ Bernoulli(0.5)
15561596
end
15571597
accum_C = x[1]
15581598
for i in 2:N
1599+
# Here `next_C` will be used twice on the second iteration
15591600
next_C ~ accum_C + x[i]
15601601
accum_C = next_C
15611602
end
15621603
obs ~ NormalMeanVariance(accum_C, sigma^2)
15631604
end
15641605

1565-
@test_throws "Trying to create duplicate edge" create_model(my_model(N = 3, sigma = 1.0)) do model, ctx
1606+
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
1607+
my_model(N = 3, sigma = 1.0)
1608+
) do model, ctx
15661609
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
15671610
return (obs = obs,)
15681611
end
1569-
end
1612+
end

test/graph_engine_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ end
759759
EdgeLabel,
760760
getname,
761761
add_edge!,
762+
has_edge,
762763
getproperties
763764

764765
include("testutils.jl")
@@ -770,12 +771,21 @@ end
770771
b = NodeLabel(:b, 2)
771772
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
772773
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum))
774+
@test !has_edge(model, a, b)
775+
@test !has_edge(model, b, a)
773776
add_edge!(model, b, getproperties(model[b]), a, :edge, 1)
777+
@test has_edge(model, a, b)
778+
@test has_edge(model, b, a)
774779
@test length(edges(model)) == 1
775780

776781
c = NodeLabel(:c, 2)
777782
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum))
783+
@test !has_edge(model, a, c)
784+
@test !has_edge(model, c, a)
778785
add_edge!(model, c, getproperties(model[c]), a, :edge, 2)
786+
@test has_edge(model, a, c)
787+
@test has_edge(model, c, a)
788+
779789
@test length(edges(model)) == 2
780790

781791
# Test 2: Test getting all edges from a model with a specific node

0 commit comments

Comments
 (0)