Skip to content

Commit 8d66ed7

Browse files
authored
Merge pull request #247 from ReactiveBayes/type-stability
Better type stability
2 parents 2472420 + f070122 commit 8d66ed7

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/graph_engine.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,16 @@ NodeBehaviour(backend, fform) = error("Backend $backend must implement a method
191191
192192
A unique identifier for a factor node in a probabilistic graphical model.
193193
"""
194-
mutable struct FactorID
195-
const fform::Any
194+
mutable struct FactorID{F}
195+
const fform::F
196196
const index::Int64
197197
end
198198

199199
fform(id::FactorID) = id.fform
200200
index(id::FactorID) = id.index
201201

202202
Base.show(io::IO, id::FactorID) = print(io, "(", fform(id), ", ", index(id), ")")
203-
Base.:(==)(id1::FactorID, id2::FactorID) = id1.fform == id2.fform && id1.index == id2.index
203+
Base.:(==)(id1::FactorID{F}, id2::FactorID{T}) where {F, T} = id1.fform == id2.fform && id1.index == id2.index
204204
Base.hash(id::FactorID, h::UInt) = hash(id.fform, hash(id.index, h))
205205

206206
"""
@@ -258,11 +258,12 @@ getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(
258258
iterate(label::NodeLabel) = (label, nothing)
259259
iterate(label::NodeLabel, any) = nothing
260260

261-
to_symbol(label::NodeLabel) = Symbol(String(label.name) * "_" * string(label.global_counter))
261+
to_symbol(label::NodeLabel) = to_symbol(label.name, label.global_counter)
262+
to_symbol(name::Any, index::Int) = Symbol(string(name, "_", index))
262263

263264
Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter)
264265
Base.:(==)(label1::NodeLabel, label2::NodeLabel) = label1.name == label2.name && label1.global_counter == label2.global_counter
265-
Base.hash(label::NodeLabel, h::UInt) = hash(label.name, hash(label.global_counter, h))
266+
Base.hash(label::NodeLabel, h::UInt) = hash(label.global_counter, h)
266267

267268
"""
268269
EdgeLabel(symbol, index)
@@ -1481,7 +1482,8 @@ end
14811482

14821483
function add_constant_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
14831484
label = __add_variable_node!(model, context, options, name, index)
1484-
context[to_symbol(label), index] = label
1485+
context[to_symbol(name, label.global_counter), index] = label # to_symbol(label) is type unstable and we know the type of label.name here from name
1486+
return label
14851487
end
14861488

14871489
function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
@@ -2130,9 +2132,9 @@ Calls a plugin specific logic after the model has been created. By default does
21302132
"""
21312133
postprocess_plugin(plugin, model) = nothing
21322134

2133-
function preprocess_plugins(type::AbstractPluginTraitType, model::Model, context::Context, label, nodedata, options)
2135+
function preprocess_plugins(type::AbstractPluginTraitType, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options)::Tuple{NodeLabel, NodeData}
21342136
plugins = filter(type, getplugins(model))
21352137
return foldl(plugins; init = (label, nodedata)) do (label, nodedata), plugin
2136-
return preprocess_plugin(plugin, model, context, label, nodedata, options)
2137-
end
2138+
return preprocess_plugin(plugin, model, context, label, nodedata, options)::Tuple{NodeLabel, NodeData}
2139+
end::Tuple{NodeLabel, NodeData}
21382140
end

0 commit comments

Comments
 (0)