You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/Core.jl
+9-10Lines changed: 9 additions & 10 deletions
Original file line number
Diff line number
Diff line change
@@ -45,17 +45,17 @@ $(TYPEDEF)
45
45
Probabilistic modeling with a tensor network.
46
46
47
47
### Fields
48
-
* `vars` are the degrees of freedom in the tensor network.
48
+
* `nvars` are the number of variables in the tensor network.
49
49
* `code` is the tensor network contraction pattern.
50
50
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
51
51
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
52
52
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
53
53
"""
54
-
struct TensorNetworkModel{LT, ET, MT <:AbstractArray}
55
-
vars::Vector{LT}
54
+
struct TensorNetworkModel{ET, MT <:AbstractArray}
55
+
nvars::Int
56
56
code::ET
57
57
tensors::Vector{MT}
58
-
evidence::Dict{LT, Int}
58
+
evidence::Dict{Int, Int}
59
59
unity_tensors_idx::Vector{Int}
60
60
end
61
61
@@ -78,7 +78,7 @@ end
78
78
79
79
function Base.show(io::IO, tn::TensorNetworkModel)
80
80
open =getiyv(tn.code)
81
-
variables =join([string_var(var, open, tn.evidence) for var intn.vars], ", ")
81
+
variables =join([string_var(var, open, tn.evidence) for var inget_vars(tn)], ", ")
82
82
tc, sc, rw =contraction_complexity(tn)
83
83
println(io, "$(typeof(tn))")
84
84
println(io, "variables: $variables")
@@ -128,25 +128,24 @@ function TensorNetworkModel(
128
128
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
Copy file name to clipboardExpand all lines: src/cspmodels.jl
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -50,7 +50,7 @@ function update_temperature(tnet::TensorNetworkModel, problem::ConstraintSatisfa
50
50
tensors, ixs =generate_tensors(β, problem)
51
51
@assert tnet.unity_tensors_idx ==collect(1:length(tnet.unity_tensors_idx)) "The target tensor network can not be updated! Got `unity_tensors_idx = $(tnet.unity_tensors_idx)`"
0 commit comments