Skip to content

Commit 861934a

Browse files
authored
Merge pull request #238 from ReactiveBayes/const-creates-twice
Constants increase model counter twice
2 parents c112d1f + dd96ff6 commit 861934a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/graph_engine.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,7 @@ getifcreated(model::Model, context::Context, var::Union{Tuple, AbstractArray{T}}
14361436
map((v) -> getifcreated(model, context, v), var)
14371437
getifcreated(model::Model, context::Context, var::ProxyLabel) = var
14381438
getifcreated(model::Model, context::Context, var) =
1439-
add_variable_node!(model, context, NodeCreationOptions(value = var, kind = :constant), gensym(model, :constvar), nothing)
1439+
add_constant_node!(model, context, NodeCreationOptions(value = var, kind = :constant), :constvar, nothing)
14401440

14411441
"""
14421442
add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
@@ -1461,14 +1461,22 @@ function add_variable_node!(model::Model, context::Context, name::Symbol, index)
14611461
end
14621462

14631463
function add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
1464+
label = __add_variable_node!(model, context, options, name, index)
1465+
context[name, index] = label
1466+
end
14641467

1468+
function add_constant_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
1469+
label = __add_variable_node!(model, context, options, name, index)
1470+
context[to_symbol(label), index] = label
1471+
end
1472+
1473+
function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
14651474
# In theory plugins are able to overwrite this
14661475
potential_label = generate_nodelabel(model, name)
14671476
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options))
14681477
label, nodedata = preprocess_plugins(
14691478
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
14701479
)
1471-
context[name, index] = label
14721480
add_vertex!(model, label, nodedata)
14731481
return label
14741482
end

test/graph_construction_tests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,6 @@ end
16781678
include("testutils.jl")
16791679

16801680
@model function neural_dot(out, in, w)
1681-
local c
16821681
c[1] ~ in[1] * w[1]
16831682
for i in 2:length(in)
16841683
c[i] ~ c[i - 1] + in[i] * w[i]

0 commit comments

Comments
 (0)