Skip to content

Commit c57bc8a

Browse files
committed
vars -> nvars
1 parent 563950b commit c57bc8a

File tree

6 files changed

+13
-17
lines changed

6 files changed

+13
-17
lines changed

src/Core.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ $(TYPEDEF)
4545
Probabilistic modeling with a tensor network.
4646
4747
### Fields
48-
* `vars` are the degrees of freedom in the tensor network.
48+
* `nvars` are the number of variables in the tensor network.
4949
* `code` is the tensor network contraction pattern.
5050
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
5151
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
5252
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
5353
"""
54-
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
55-
vars::Vector{LT}
54+
struct TensorNetworkModel{ET, MT <: AbstractArray}
55+
nvars::Int
5656
code::ET
5757
tensors::Vector{MT}
58-
evidence::Dict{LT, Int}
58+
evidence::Dict{Int, Int}
5959
unity_tensors_idx::Vector{Int}
6060
end
6161

@@ -78,7 +78,7 @@ end
7878

7979
function Base.show(io::IO, tn::TensorNetworkModel)
8080
open = getiyv(tn.code)
81-
variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
81+
variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ")
8282
tc, sc, rw = contraction_complexity(tn)
8383
println(io, "$(typeof(tn))")
8484
println(io, "variables: $variables")
@@ -128,25 +128,24 @@ function TensorNetworkModel(
128128
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
129129
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
130130
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
131-
return TensorNetworkModel(collect(Int, 1:model.nvars), code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
131+
return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
132132
end
133133

134134
"""
135135
$(TYPEDSIGNATURES)
136136
137137
Get the variables in this tensor network, they are also known as legs, labels, or degree of freedoms.
138138
"""
139-
get_vars(tn::TensorNetworkModel)::Vector = tn.vars
139+
get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars
140140

141141
"""
142142
$(TYPEDSIGNATURES)
143143
144-
Get the cardinalities of variables in this tensor network.
144+
Get the ardinalities of variables in this tensor network.
145145
"""
146146
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
147-
vars = get_vars(tn)
148147
size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
149-
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
148+
[fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars]
150149
end
151150

152151
chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)

src/cspmodels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function update_temperature(tnet::TensorNetworkModel, problem::ConstraintSatisfa
5050
tensors, ixs = generate_tensors(β, problem)
5151
@assert tnet.unity_tensors_idx == collect(1:length(tnet.unity_tensors_idx)) "The target tensor network can not be updated! Got `unity_tensors_idx = $(tnet.unity_tensors_idx)`"
5252
alltensors = [tnet.tensors[tnet.unity_tensors_idx]..., tensors...]
53-
return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx)
53+
return TensorNetworkModel(tnet.nvars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx)
5454
end
5555

5656
function MMAPModel(problem::ConstraintSatisfactionProblem, β::Real;

src/map.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,11 @@ $(TYPEDSIGNATURES)
5353
Returns the largest log-probability and the most probable configuration.
5454
"""
5555
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
56-
vars = get_vars(tn)
57-
tensor_indices = check_queryvars(tn, [[v] for v in vars])
56+
tensor_indices = check_queryvars(tn, [[v] for v in 1:tn.nvars])
5857
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
5958
logp, grads = cost_and_gradient(tn.code, tensors)
6059
# use Array to convert CuArray to CPU arrays
61-
return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tensor_indices[k]]) - 1, 1:length(vars))
60+
return content(Array(logp)[]), map(k -> haskey(tn.evidence, k) ? tn.evidence[k] : argmax(grads[tensor_indices[k]]) - 1, 1:tn.nvars)
6261
end
6362
# check if the queryvars are included in the unity tensors labels, if yes, return the indices of the unity tensors
6463
function check_queryvars(tn::TensorNetworkModel, queryvars::AbstractVector{Vector{Int}})

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher
352352
push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]])
353353
tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...]
354354
return TensorNetworkModel(
355-
collect(1:3n-2),
355+
3n-2,
356356
optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
357357
tensors,
358358
Dict{Int, Int}(),

test/mar.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Test
22
using OMEinsum
3-
using KaHyPar
43
using TensorInference
54

65
@testset "composite number" begin

test/pr.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Test
22
using OMEinsum
3-
using KaHyPar
43
using TensorInference
54

65
@testset "UAI Reference Solution Comparison" begin

0 commit comments

Comments
 (0)