Skip to content

Commit ec8c476

Browse files
committed
Throw warning if we compare variableref to anything other than variableref
1 parent b3ae2a6 commit ec8c476

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

src/graph_engine.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,24 @@ struct VariableRef{M, C, O, I, E, L}
879879
internal_collection::L
880880
end
881881

882+
Base.:(==)(left::VariableRef, right::VariableRef) =
883+
left.model == right.model && left.context == right.context && left.name == right.name && left.index == right.index
884+
885+
function Base.:(==)(left::VariableRef, right)
886+
@warn "Comparing Factor Graph variable ($left) with a value. This is not possible as the value of $left is not known at model construction time."
887+
return false
888+
end
889+
Base.:(==)(left, right::VariableRef) = right == left
890+
891+
Base.:(>)(left::VariableRef, right) = left == right
892+
Base.:(>)(left, right::VariableRef) = left == right
893+
Base.:(<)(left::VariableRef, right) = left == right
894+
Base.:(<)(left, right::VariableRef) = left == right
895+
Base.:(>=)(left::VariableRef, right) = left == right
896+
Base.:(>=)(left, right::VariableRef) = left == right
897+
Base.:(<=)(left::VariableRef, right) = left == right
898+
Base.:(<=)(left, right::VariableRef) = left == right
899+
882900
is_proxied(::Type{T}) where {T <: VariableRef} = True()
883901

884902
external_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O, I, E, L} = E

test/graph_construction_tests.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,3 +1715,64 @@ end
17151715
@test length(collect(filter(as_variable(:in), model))) == 3
17161716
@test length(collect(filter(as_variable(:out), model))) == 2
17171717
end
1718+
1719+
@testitem "Comparing variables throws warning" begin
1720+
import GraphPPL: create_model, getorcreate!
1721+
1722+
include("testutils.jl")
1723+
@model function test_model(y)
1724+
x ~ Normal(0.0, 1.0)
1725+
if x == 0
1726+
z ~ Normal(0.0, 1.0)
1727+
else
1728+
z ~ Normal(1.0, 1.0)
1729+
end
1730+
end
1731+
1732+
@test_logs (
1733+
:warn,
1734+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
1735+
) create_model(test_model(y = 1))
1736+
1737+
@model function test_model(y)
1738+
x ~ Normal(0.0, 1.0)
1739+
if x > 0
1740+
z ~ Normal(0.0, 1.0)
1741+
else
1742+
z ~ Normal(1.0, 1.0)
1743+
end
1744+
end
1745+
1746+
@test_logs (
1747+
:warn,
1748+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
1749+
) create_model(test_model(y = 1))
1750+
1751+
@model function test_model(y)
1752+
x ~ Normal(0.0, 1.0)
1753+
if x < 0
1754+
z ~ Normal(0.0, 1.0)
1755+
else
1756+
z ~ Normal(1.0, 1.0)
1757+
end
1758+
end
1759+
1760+
@test_logs (
1761+
:warn,
1762+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
1763+
) create_model(test_model(y = 1))
1764+
1765+
@model function test_model(y)
1766+
x ~ Normal(0.0, 1.0)
1767+
if 0 >= x
1768+
z ~ Normal(0.0, 1.0)
1769+
else
1770+
z ~ Normal(1.0, 1.0)
1771+
end
1772+
end
1773+
1774+
@test_logs (
1775+
:warn,
1776+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
1777+
) create_model(test_model(y = 1))
1778+
end

test/graph_engine_tests.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,86 @@ end
814814
end
815815
end
816816

817+
@testitem "`VariableRef` comparison" begin
818+
import GraphPPL:
819+
VariableRef,
820+
makevarref,
821+
getcontext,
822+
getifcreated,
823+
unroll,
824+
ProxyLabel,
825+
NodeLabel,
826+
proxylabel,
827+
NodeCreationOptions,
828+
VariableKindRandom,
829+
VariableKindData,
830+
getproperties,
831+
is_kind,
832+
MissingCollection,
833+
getorcreate!
834+
835+
using Distributions
836+
837+
include("testutils.jl")
838+
839+
model = create_test_model()
840+
ctx = getcontext(model)
841+
xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,))
842+
@test xref == xref
843+
@test_logs (
844+
:warn,
845+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
846+
) xref != 1
847+
@test_logs (
848+
:warn,
849+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
850+
) 1 != xref
851+
@test_logs (
852+
:warn,
853+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
854+
) xref == 1
855+
@test_logs (
856+
:warn,
857+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
858+
) 1 == xref
859+
@test_logs (
860+
:warn,
861+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
862+
) xref > 0
863+
@test_logs (
864+
:warn,
865+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
866+
) 0 < xref
867+
@test_logs (
868+
:warn,
869+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
870+
) "something" == xref
871+
@test_logs (
872+
:warn,
873+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
874+
) 10 > xref
875+
@test_logs (
876+
:warn,
877+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
878+
) xref < 10
879+
@test_logs (
880+
:warn,
881+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
882+
) 0 <= xref
883+
@test_logs (
884+
:warn,
885+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
886+
) xref >= 0
887+
@test_logs (
888+
:warn,
889+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
890+
) xref <= 0
891+
@test_logs (
892+
:warn,
893+
"Comparing Factor Graph variable (x) with a value. This is not possible as the value of x is not known at model construction time."
894+
) 0 >= xref
895+
end
896+
817897
@testitem "NodeLabel properties" begin
818898
import GraphPPL: NodeLabel
819899

0 commit comments

Comments
 (0)