Skip to content

Commit 6f99b13

Browse files
authored
Merge pull request #264 from ReactiveBayes/fix262
Fix issue 262
2 parents ebc979a + 832809d commit 6f99b13

File tree

4 files changed

+66
-6
lines changed

4 files changed

+66
-6
lines changed

src/graph_engine.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ end
341341
getname(label::ProxyLabel) = label.name
342342
index(label::ProxyLabel) = label.index
343343

344+
# This function allows to overwrite the `maycreate` flag on a proxy label, might be useful for situations where code should
345+
# definitely not create a new variable, e.g in the variational constraints plugin
346+
set_maycreate(proxylabel::ProxyLabel, maycreate::Union{True, False}) =
347+
ProxyLabel(proxylabel.name, proxylabel.proxied, proxylabel.index, maycreate)
348+
set_maycreate(something, maycreate::Union{True, False}) = something
349+
344350
function unroll(something)
345351
return something
346352
end

src/plugins/variational_constraints/variational_constraints_engine.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -754,21 +754,25 @@ function resolve(model::Model, context::Context, variable::IndexedVariable{<:Spl
754754
)
755755
end
756756

757+
# The variational constraints plugin should not attempt to create new variables in the model
758+
# Even if the `maycreate` flag was set to `True` at this point we assume that the variable has been created already
759+
unroll_nocreate(something) = unroll(set_maycreate(something, False()))
760+
757761
function resolve(model::Model, context::Context, variable::IndexedVariable{Nothing})
758-
global_label = unroll(context[getname(variable)])
762+
global_label = unroll_nocreate(context[getname(variable)])
759763
return __resolve(model, global_label)
760764
end
761765

762766
function resolve(model::Model, context::Context, variable::IndexedVariable)
763-
global_label = unroll(context[getname(variable)])[index(variable)...]
767+
global_label = unroll_nocreate(context[getname(variable)])[index(variable)...]
764768
return __resolve(model, global_label)
765769
end
766770

767771
resolve(model::Model, context::Context, variable::IndexedVariable{CombinedRange{NTuple{N, Int}, NTuple{N, Int}}}) where {N} =
768772
throw(UnresolvableFactorizationConstraintError("Cannot resolve factorization constraint for a combined range of dimension > 2."))
769773

770774
function resolve(model::Model, context::Context, variable::IndexedVariable{<:CombinedRange})
771-
global_label = view(unroll(context[getname(variable)]), firstindex(index(variable)):lastindex(index(variable)))
775+
global_label = view(unroll_nocreate(context[getname(variable)]), firstindex(index(variable)):lastindex(index(variable)))
772776
return __resolve(model, global_label)
773777
end
774778

@@ -803,7 +807,8 @@ function resolve(model::Model, context::Context, constraint::FactorizationConstr
803807
end
804808
lhs = map(variable -> resolve(model, context, variable), vfiltered)
805809
rhs = map(
806-
variable -> ResolvedFactorizationConstraintEntry((resolve(model, context, unroll(context[getname(variable)]), MeanField()),)),
810+
variable ->
811+
ResolvedFactorizationConstraintEntry((resolve(model, context, unroll_nocreate(context[getname(variable)]), MeanField()),)),
807812
vfiltered
808813
)
809814
return ResolvedFactorizationConstraint(ResolvedConstraintLHS(lhs), rhs)
@@ -930,7 +935,7 @@ end
930935
function apply_constraints!(
931936
model::Model, context::Context, marginal_constraint::MarginalFormConstraint{T, F} where {T <: IndexedVariable, F}
932937
)
933-
applicable_nodes = unroll(context[getvariables(marginal_constraint)])
938+
applicable_nodes = unroll_nocreate(context[getvariables(marginal_constraint)])
934939
for node in applicable_nodes
935940
if hasextra(model[node], VariationalConstraintsMarginalFormConstraintKey)
936941
@warn lazy"Node $node already has functional form constraint $(opt[:q]) applied, therefore $constraint_data will not be applied"
@@ -945,7 +950,7 @@ function apply_constraints!(model::Model, context::Context, marginal_constraint:
945950
end
946951

947952
function apply_constraints!(model::Model, context::Context, message_constraint::MessageFormConstraint)
948-
applicable_nodes = unroll(context[getvariables(message_constraint)])
953+
applicable_nodes = unroll_nocreate(context[getvariables(message_constraint)])
949954
for node in applicable_nodes
950955
if hasextra(model[node], VariationalConstraintsMessagesFormConstraintKey)
951956
@warn lazy"Node $node already has functional form constraint $(opt[:q]) applied, therefore $constraint_data will not be applied"

test/graph_engine_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ end
712712
getcontext,
713713
getifcreated,
714714
unroll,
715+
set_maycreate,
715716
ProxyLabel,
716717
NodeLabel,
717718
proxylabel,
@@ -812,6 +813,27 @@ end
812813
# `x` can be inferred properly
813814
@test ctx[:x] === @inferred(unroll(proxylabel(:x, xref, nothing, False())))
814815
end
816+
817+
@testset "It should be possible to toggle `maycreate` flag" begin
818+
model = create_test_model()
819+
ctx = getcontext(model)
820+
xref = VariableRef(model, ctx, NodeCreationOptions(), :x, (nothing,))
821+
# The first time should throw since the variable has not been instantiated yet
822+
@test_throws "The variable `x` has been used, but has not been instantiated." unroll(proxylabel(:x, xref, nothing, False()))
823+
# Even though the `maycreate` flag is set to `True`, the `set_maycreate` should overwrite it with `False`
824+
@test_throws "The variable `x` has been used, but has not been instantiated." unroll(
825+
set_maycreate(proxylabel(:x, xref, nothing, True()), False())
826+
)
827+
828+
# Even though the `maycreate` flag is set to `False`, the `set_maycreate` should overwrite it with `True`
829+
@test unroll(set_maycreate(proxylabel(:x, xref, nothing, False()), True())) === ctx[:x]
830+
# At this point the variable should be created
831+
@test unroll(proxylabel(:x, xref, nothing, False())) === ctx[:x]
832+
@test unroll(proxylabel(:x, xref, nothing, True())) === ctx[:x]
833+
834+
@test set_maycreate(1, True()) === 1
835+
@test set_maycreate(1, False()) === 1
836+
end
815837
end
816838

817839
@testitem "`VariableRef` comparison" begin

test/plugins/variational_constraints/variational_constraints_tests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,3 +1107,30 @@ end
11071107
end
11081108
end
11091109
end
1110+
1111+
@testitem "Issue 262, factorization constraint should not attempt to create a variable from submodels" begin
1112+
import GraphPPL: create_model, with_plugins, getproperties, neighbor_data, is_factorized
1113+
1114+
include("../../testutils.jl")
1115+
1116+
@model function submodel(y, x)
1117+
for i in 1:10
1118+
y[i] ~ Normal(x, 1)
1119+
end
1120+
end
1121+
1122+
@model function main_model()
1123+
x ~ Normal(0, 1)
1124+
y ~ submodel(x = x)
1125+
end
1126+
1127+
constraints = @constraints begin
1128+
for q in submodel
1129+
q(x, y) = q(x)q(y)
1130+
end
1131+
end
1132+
1133+
model = create_model(with_plugins(main_model(), GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(constraints))))
1134+
1135+
@test length(collect(filter(as_node(Normal), model))) == 11
1136+
end

0 commit comments

Comments
 (0)