Skip to content

Commit 81aadda

Browse files
committed
param factorID
1 parent 8d3fc7c commit 81aadda

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/graph_engine.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,16 @@ NodeBehaviour(backend, fform) = error("Backend $backend must implement a method
191191
192192
A unique identifier for a factor node in a probabilistic graphical model.
193193
"""
194-
mutable struct FactorID
195-
const fform::Any
194+
mutable struct FactorID{F}
195+
const fform::F
196196
const index::Int64
197197
end
198198

199199
fform(id::FactorID) = id.fform
200200
index(id::FactorID) = id.index
201201

202202
Base.show(io::IO, id::FactorID) = print(io, "(", fform(id), ", ", index(id), ")")
203-
Base.:(==)(id1::FactorID, id2::FactorID) = id1.fform == id2.fform && id1.index == id2.index
203+
Base.:(==)(id1::FactorID{F}, id2::FactorID{T}) where {F, T} = id1.fform == id2.fform && id1.index == id2.index
204204
Base.hash(id::FactorID, h::UInt) = hash(id.fform, hash(id.index, h))
205205

206206
"""
@@ -258,7 +258,8 @@ getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(
258258
iterate(label::NodeLabel) = (label, nothing)
259259
iterate(label::NodeLabel, any) = nothing
260260

261-
to_symbol(label::NodeLabel) = Symbol(String(label.name) * "_" * string(label.global_counter))
261+
to_symbol(label::NodeLabel) = to_symbol(label.name, label.global_counter)
262+
to_symbol(name::D, index::Int) where {D} = Symbol(String(name) * "_" * string(index))
262263

263264
Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter)
264265
Base.:(==)(label1::NodeLabel, label2::NodeLabel) = label1.name == label2.name && label1.global_counter == label2.global_counter
@@ -1481,7 +1482,7 @@ end
14811482

14821483
function add_constant_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
14831484
label = __add_variable_node!(model, context, options, name, index)
1484-
context[Symbol(String(name) * "_" * string(label.global_counter)), index] = label
1485+
context[to_symbol(name, label.global_counter), index] = label # to_symbol(label) is type unstable and we know the type of label.name here from name
14851486
return label
14861487
end
14871488

0 commit comments

Comments
 (0)