Skip to content

Commit bd5f843

Browse files
authored
Merge pull request #243 from ReactiveBayes/equals
Throw warning if we compare `VariableRef` to anything other than `VariableRef`
2 parents 5fdf96e + ff106fa commit bd5f843

File tree

4 files changed

+161
-1
lines changed

4 files changed

+161
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,5 @@ benchmark_*md
5353
examples/*Compiled
5454
statprof
5555
profile.pb.gz
56-
.swp
56+
.swp
57+
*.info

src/graph_engine.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,25 @@ struct VariableRef{M, C, O, I, E, L}
879879
internal_collection::L
880880
end
881881

882+
Base.:(==)(left::VariableRef, right::VariableRef) =
883+
left.model == right.model && left.context == right.context && left.name == right.name && left.index == right.index
884+
885+
function Base.:(==)(left::VariableRef, right)
886+
error(
887+
"Comparing Factor Graph variable `$left` with a value. This is not possible as the value of `$left` is not known at model construction time."
888+
)
889+
end
890+
Base.:(==)(left, right::VariableRef) = right == left
891+
892+
Base.:(>)(left::VariableRef, right) = left == right
893+
Base.:(>)(left, right::VariableRef) = left == right
894+
Base.:(<)(left::VariableRef, right) = left == right
895+
Base.:(<)(left, right::VariableRef) = left == right
896+
Base.:(>=)(left::VariableRef, right) = left == right
897+
Base.:(>=)(left, right::VariableRef) = left == right
898+
Base.:(<=)(left::VariableRef, right) = left == right
899+
Base.:(<=)(left, right::VariableRef) = left == right
900+
882901
is_proxied(::Type{T}) where {T <: VariableRef} = True()
883902

884903
external_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = E

test/graph_construction_tests.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,3 +1715,59 @@ end
17151715
@test length(collect(filter(as_variable(:in), model))) == 3
17161716
@test length(collect(filter(as_variable(:out), model))) == 2
17171717
end
1718+
1719+
@testitem "Comparing variables throws warning" begin
1720+
import GraphPPL: create_model, getorcreate!
1721+
1722+
include("testutils.jl")
1723+
@model function test_model(y)
1724+
x ~ Normal(0.0, 1.0)
1725+
if x == 0
1726+
z ~ Normal(0.0, 1.0)
1727+
else
1728+
z ~ Normal(1.0, 1.0)
1729+
end
1730+
end
1731+
1732+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
1733+
test_model(y = 1)
1734+
)
1735+
1736+
@model function test_model(y)
1737+
x ~ Normal(0.0, 1.0)
1738+
if x > 0
1739+
z ~ Normal(0.0, 1.0)
1740+
else
1741+
z ~ Normal(1.0, 1.0)
1742+
end
1743+
end
1744+
1745+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
1746+
test_model(y = 1)
1747+
)
1748+
@model function test_model(y)
1749+
x ~ Normal(0.0, 1.0)
1750+
if x < 0
1751+
z ~ Normal(0.0, 1.0)
1752+
else
1753+
z ~ Normal(1.0, 1.0)
1754+
end
1755+
end
1756+
1757+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
1758+
test_model(y = 1)
1759+
)
1760+
1761+
@model function test_model(y)
1762+
x ~ Normal(0.0, 1.0)
1763+
if 0 >= x
1764+
z ~ Normal(0.0, 1.0)
1765+
else
1766+
z ~ Normal(1.0, 1.0)
1767+
end
1768+
end
1769+
1770+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." create_model(
1771+
test_model(y = 1)
1772+
)
1773+
end

test/graph_engine_tests.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,90 @@ end
814814
end
815815
end
816816

817+
@testitem "`VariableRef` comparison" begin
818+
import GraphPPL:
819+
VariableRef,
820+
makevarref,
821+
getcontext,
822+
getifcreated,
823+
unroll,
824+
ProxyLabel,
825+
NodeLabel,
826+
proxylabel,
827+
NodeCreationOptions,
828+
VariableKindRandom,
829+
VariableKindData,
830+
getproperties,
831+
is_kind,
832+
MissingCollection,
833+
getorcreate!
834+
835+
using Distributions
836+
837+
include("testutils.jl")
838+
839+
model = create_test_model()
840+
ctx = getcontext(model)
841+
xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,))
842+
@test xref == xref
843+
@test_throws(
844+
"Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time.",
845+
xref != 1
846+
)
847+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 !=
848+
xref
849+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref ==
850+
1
851+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 1 ==
852+
xref
853+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref >
854+
0
855+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 <
856+
xref
857+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." "something" ==
858+
xref
859+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 10 >
860+
xref
861+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref <
862+
10
863+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 <=
864+
xref
865+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref >=
866+
0
867+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." xref <=
868+
0
869+
@test_throws "Comparing Factor Graph variable `x` with a value. This is not possible as the value of `x` is not known at model construction time." 0 >=
870+
xref
871+
872+
xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (1, 2))
873+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref !=
874+
1
875+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 !=
876+
xref
877+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref ==
878+
1
879+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 1 ==
880+
xref
881+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref >
882+
0
883+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 <
884+
xref
885+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." "something" ==
886+
xref
887+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 10 >
888+
xref
889+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref <
890+
10
891+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 <=
892+
xref
893+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref >=
894+
0
895+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." xref <=
896+
0
897+
@test_throws "Comparing Factor Graph variable `x[1,2]` with a value. This is not possible as the value of `x[1,2]` is not known at model construction time." 0 >=
898+
xref
899+
end
900+
817901
@testitem "NodeLabel properties" begin
818902
import GraphPPL: NodeLabel
819903

0 commit comments

Comments
 (0)