Skip to content

Commit 0806a5f

Browse files
committed
update
1 parent 36e047e commit 0806a5f

File tree

7 files changed

+136
-50
lines changed

7 files changed

+136
-50
lines changed

src/TensorInference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ export MMAPModel
4141
export update_temperature
4242

4343
# belief propagation
44-
export belief_propagation
44+
export BeliefPropgation, belief_propagate
4545

4646
# utils
47-
export random_matrix_product_state
47+
export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai
4848

4949
include("Core.jl")
5050
include("RescaledArray.jl")

src/belief.jl

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
1-
struct BPState{T, VT<:AbstractVector{T}}
1+
"""
    BeliefPropgation{T}

Static description of a belief-propagation problem on a tensor network.

# Fields
- `t2v`: a mapping from tensors to variables; `t2v[i]` lists the variables of tensor `i`.
- `v2t`: a mapping from variables to tensors; `v2t[v]` lists the tensors touching variable `v`.
- `tensors`: the tensors, with element type `T`.
"""
struct BeliefPropgation{T}
    t2v::Vector{Vector{Int}}           # a mapping from tensors to variables
    v2t::Vector{Vector{Int}}           # a mapping from variables to tensors
    tensors::Vector{AbstractArray{T}}  # the tensors
end
6+
# Number of tensors in the problem, i.e. the number of entries in `bp.t2v`.
function num_tensors(bp::BeliefPropgation)
    return length(bp.t2v)
end
7+
# Number of variables in the problem, i.e. the number of entries in `bp.v2t`.
function ProblemReductions.num_variables(bp::BeliefPropgation)
    return length(bp.v2t)
end
8+
9+
"""
    BeliefPropgation(nvars, t2v, tensors)

Build a `BeliefPropgation` instance from the tensor-to-variable mapping `t2v`,
deriving the inverse (variable-to-tensor) mapping.

# Arguments
- `nvars`: number of variables; variables are labelled `1:nvars`.
- `t2v`: `t2v[i]` lists the variables attached to the `i`-th tensor.
- `tensors`: the tensors, in the same order as `t2v`.

# Throws
- `DimensionMismatch` if `t2v` and `tensors` have different lengths.
- `ArgumentError` if a variable label lies outside `1:nvars`.
"""
function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T
    length(t2v) == length(tensors) || throw(DimensionMismatch(
        "got $(length(t2v)) variable lists but $(length(tensors)) tensors"))
    # initialize the inverse mapping
    v2t = [Int[] for _ in 1:nvars]
    for (i, edge) in enumerate(t2v)
        for v in edge
            1 <= v <= nvars || throw(ArgumentError(
                "tensor $i refers to variable $v, which is outside 1:$nvars"))
            push!(v2t[v], i)
        end
    end
    # The explicitly parametrized constructor converts its arguments to the field
    # types, so any `AbstractVector` accepted by this signature works (the untyped
    # default constructor only matches plain `Vector`s).
    return BeliefPropgation{T}(t2v, v2t, tensors)
end
19+
"""
    BeliefPropgation(uai::UAIModel)

Build a `BeliefPropgation` instance from a UAI model. Every factor contributes
its variable list (`f.vars`) and its value array (`f.vals`); the number of
variables is taken from `uai.nvars`.
"""
function BeliefPropgation(uai::UAIModel{T}) where T
    # Typed comprehensions keep the element types fixed even if `uai.factors` is empty.
    t2v = Vector{Int}[collect(Int, f.vars) for f in uai.factors]
    tensors = AbstractArray{T}[f.vals for f in uai.factors]
    return BeliefPropgation(uai.nvars, t2v, tensors)
end
22+
23+
"""
    BPState{T, VT<:AbstractVector{T}}

Mutable message storage for belief propagation.

# Fields
- `message_in`: for each variable, the incoming messages.
- `message_out`: for each variable, the outgoing messages.

Both fields are indexed as `messages[v][k]`, where `v` is a variable and `k` a
position in that variable's message list.
"""
struct BPState{T, VT<:AbstractVector{T}}
    message_in::Vector{Vector{VT}}   # for each variable, we store the incoming messages
    message_out::Vector{Vector{VT}}  # the outgoing messages
end
@@ -22,20 +41,20 @@ function _process_message!(ov::Vector, iv::Vector)
2241
end
2342
end
2443

25-
function collect_message!(bp::BPState)
26-
for (it, t) in enumerate(bp.t2v)
27-
_collect_message!(vectors_on_tensor(bp.message_out, bp, it), t, vectors_on_tensor(bp.message_in, bp, it))
44+
# For every tensor of `bp`, gather the outgoing- and incoming-message vectors of
# the variables attached to it and hand them, together with the tensor itself,
# to `_collect_message!`. Returns `nothing`.
function collect_message!(bp::BeliefPropgation, state::BPState)
    for it in 1:num_tensors(bp)
        outgoing = vectors_on_tensor(state.message_out, bp, it)
        incoming = vectors_on_tensor(state.message_in, bp, it)
        _collect_message!(outgoing, bp.tensors[it], incoming)
    end
end
3049
# collect the vectors associated with the target tensor
31-
function vectors_on_tensor(messages, bp::BPState, it::Int)
50+
# Select, for each variable `v` attached to tensor `it`, the entry of
# `messages[v]` that belongs to tensor `it`. The returned vector holds the
# stored message objects themselves (no copies are made).
function vectors_on_tensor(messages, bp::BeliefPropgation, it::Int)
    return map(bp.t2v[it]) do v
        # position of tensor `it` among the tensors attached to variable `v`
        slot = findfirst(==(it), bp.v2t[v])
        slot === nothing && throw(ArgumentError(
            "inconsistent mappings: tensor $it lists variable $v, but v2t[$v] does not contain $it"))
        messages[v][slot]
    end
end
3756
function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector)
38-
@assert length(vectors_out) == length(vectors_in) == ndims(t)
57+
@assert length(vectors_out) == length(vectors_in) == ndims(t) "dimensions mismatch: $(length(vectors_out)), $(length(vectors_in)), $(ndims(t))"
3958
# TODO: speed up if needed!
4059
code = star_code(length(vectors_in))
4160
cost, gradient = cost_and_gradient(code, [t, vectors_in...])
@@ -44,6 +63,8 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
4463
end
4564
return cost
4665
end
66+
67+
# star code: contract a tensor with multiple vectors, one for each dimension
4768
function star_code(n::Int)
4869
ix1, ixrest = collect(1:n), [[i] for i in 1:n]
4970
ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n)))
@@ -53,41 +74,38 @@ function star_code(n::Int)
5374
return ne
5475
end
5576

56-
function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T
57-
v2t = [Int[] for _ in 1:n]
58-
edges_vectors = [Vector{VT}[] for _ in 1:n]
59-
for (i, edge) in enumerate(t2v)
60-
for v in edge
61-
push!(v2t[v], i)
62-
push!(edges_vectors[i], ones(T, size_dict[v]))
63-
end
77+
# Create the initial `BPState` for `bp`: for every variable, one all-ones
# message vector per attached tensor, whose length is the variable's size as
# reported by `OMEinsum.get_size_dict`. Incoming and outgoing messages start
# equal but are stored independently.
function initial_state(bp::BeliefPropgation{T}) where T
    size_dict = OMEinsum.get_size_dict(bp.t2v, bp.tensors)
    edges_vectors = Vector{Vector{T}}[
        [ones(T, size_dict[v]) for _ in tids] for (v, tids) in enumerate(bp.v2t)
    ]
    # deepcopy so that updating one side never aliases the other
    return BPState(deepcopy(edges_vectors), edges_vectors)
end
6785

6886
# belief propagation, update the tensors on the edges of the tensor network
69-
function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T
70-
# collect the messages from the neighbors
71-
messages = [similar(bpstate.edges_vectors[it]) for it in 1:length(bpstate.t2v)]
72-
for (it, vs) in enumerate(bpstate.t2v)
73-
for (iv, v) in enumerate(vs)
74-
messages[it][iv] = tn.tensors[v]
75-
end
76-
end
77-
# update the tensors on the edges of the tensor network
78-
for (it, vs) in enumerate(bpstate.t2v)
79-
# update the tensor
80-
for (iv, v) in enumerate(vs)
81-
bpstate.edges_vectors[it][iv] = zeros(T, size_dict[v])
82-
for (j, w) in enumerate(vs)
83-
if j != iv
84-
bpstate.edges_vectors[it][iv] += messages[j][iv] * messages[j][iv]
85-
end
86-
end
87+
# Run belief propagation on `bp`, starting from `initial_state(bp)`.
#
# Keyword arguments:
# - `max_iter`: maximum number of iterations.
# - `tol`: absolute tolerance used by the convergence check.
#
# Returns `(state, info)`: the final `BPState` and the `BPInfo` produced by
# `belief_propagate!`.
function belief_propagate(bp::BeliefPropgation; max_iter::Int=100, tol::Float64=1e-6)
    state = initial_state(bp)
    info = belief_propagate!(bp, state; max_iter=max_iter, tol=tol)
    return state, info
end
92+
"""
    BPInfo

Outcome of a belief-propagation run.

# Fields
- `converged`: whether the convergence check succeeded.
- `iterations`: the iteration at which it succeeded, or the iteration limit otherwise.
"""
struct BPInfo
    converged::Bool
    iterations::Int
end
96+
# Iterate belief propagation in place on `state`.
#
# Each iteration calls `process_message!(state)` followed by
# `collect_message!(bp, state)`, then compares every outgoing message with the
# matching incoming message using `isapprox` with absolute tolerance `tol`.
#
# Returns `BPInfo(true, i)` at the first iteration `i` where all messages agree,
# or `BPInfo(false, max_iter)` if that never happens within `max_iter` iterations.
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T
    for i in 1:max_iter
        process_message!(state)
        collect_message!(bp, state)
        # check convergence: one message per (variable, attached tensor) pair
        converged = all(eachindex(bp.v2t)) do iv
            all(eachindex(bp.v2t[iv])) do it
                isapprox(state.message_out[iv][it], state.message_in[iv][it]; atol=tol)
            end
        end
        converged && return BPInfo(true, i)
    end
    return BPInfo(false, max_iter)
end
90107

91-
function belief_propagation(tn::TensorNetworkModel{T}) where T
92-
return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict))
93-
end
108+
"""
    contraction_results(state::BPState)

For each variable, multiply all of its incoming messages element-wise and sum
the entries of the product.

If BP is exact and converged (e.g. tree like), the result should be the same as
the tensor network contraction.

NOTE(review): a variable without any incoming message makes `reduce` fail on an
empty collection; confirm whether such variables can occur.
"""
function contraction_results(state::BPState{T}) where T
    return [sum(reduce((x, y) -> x .* y, msgs)) for msgs in state.message_in]
end

src/utils.jl

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,12 @@ connected in a chain.
334334
- `d` is the dimension of the physical indices.
335335
"""
336336
function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) where T
    # Generate the factors as a UAI model, then wrap them in a tensor network
    # model whose contraction order is chosen by `GreedyMethod()`.
    uai = random_matrix_product_uai(T, n, chi, d)
    return TensorNetworkModel(uai; optimizer=GreedyMethod())
end
340+
# Convenience method: same as the typed method with element type `ComplexF64`.
function random_matrix_product_state(n::Int, chi::Int, d::Int=2)
    return random_matrix_product_state(ComplexF64, n, chi, d)
end
341+
342+
function random_matrix_product_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T
337343
# chi ^ (n-1) * (variance^n)^2 == 1/d^n
338344
variance = d^(-1/2) * chi^(-1/2+1/2n)
339345
tensors = Any[randn(T, d, chi) .* variance]
@@ -351,12 +357,41 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher
351357
push!(ixs_ket, [virtual_indices_ket[n-1], physical_indices[n]])
352358
push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]])
353359
tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...]
354-
return TensorNetworkModel(
355-
3n-2,
356-
optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
357-
tensors,
358-
Dict{Int, Int}(),
359-
collect(1:n)
360+
size_dict = OMEinsum.get_size_dict(ixs, tensors)
361+
nvars = 3n-2
362+
return UAIModel(
363+
nvars,
364+
[size_dict[i] for i=1:nvars],
365+
[Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)]
360366
)
361367
end
362-
random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d)
368+
369+
370+
"""
371+
$TYPEDSIGNATURES
372+
373+
Tensor train (TT) is a tensor network model that is widely used in quantum
374+
many-body physics. This model is different from the matrix product state (MPS)
375+
in that it does not have an extra copy for representing the bra state.
376+
"""
377+
function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T
    # Arguments: `T` element type, `n` number of sites (physical indices),
    # `chi` dimension of the virtual indices, `d` dimension of the physical indices.
    # Variables 1:n are the physical indices, n+1:2n-1 the virtual ones.
    # Throws `ArgumentError` when `n < 2` (there would be no virtual index).
    n >= 2 || throw(ArgumentError("a tensor train needs at least 2 sites, got n = $n"))

    # chi ^ (n-1) * (variance^n)^2 == 1/d^n
    # Converted to `real(T)` so that scaling does not promote the tensors away
    # from the requested element type (e.g. Float32 -> Float64).
    variance = convert(real(T), d^(-1/2) * chi^(-1/2+1/2n))

    physical_indices = collect(1:n)
    virtual_indices = collect(n+1:2n-1)

    # first site: (physical, right virtual)
    tensors = Array{T}[randn(T, d, chi) .* variance]
    ixs = [[physical_indices[1], virtual_indices[1]]]
    # bulk sites: (left virtual, physical, right virtual)
    for i = 2:n-1
        push!(tensors, randn(T, chi, d, chi) .* variance)
        push!(ixs, [virtual_indices[i-1], physical_indices[i], virtual_indices[i]])
    end
    # last site: (left virtual, physical)
    push!(tensors, randn(T, chi, d) .* variance)
    push!(ixs, [virtual_indices[n-1], physical_indices[n]])

    size_dict = OMEinsum.get_size_dict(ixs, tensors)
    nvars = 2n-1
    return UAIModel(
        nvars,
        [size_dict[i] for i=1:nvars],
        [Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)]
    )
end

test/belief.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TensorInference, Test
2+
using OMEinsum
23

34
@testset "process message" begin
45
mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
@@ -26,11 +27,24 @@ end
2627
@test vectors_out[3] vec(kron(v2, v1)' * reshape(t, 4, 2))
2728
end
2829

30+
@testset "constructor" begin
    # Build the BP problem from a UAI benchmark instance and check the sizes of
    # both mappings.
    problem = problem_from_artifact("uai2014", "MAR", "Promedus", 14)
    uai = read_model(problem)
    bp = BeliefPropgation(uai)
    @test length(bp.v2t) == 414
    @test TensorInference.num_tensors(bp) == 414
    # every variable must be referenced by at least one factor
    used_vars = unique(reduce(vcat, [collect(Int, f.vars) for f in uai.factors]))
    @test TensorInference.num_variables(bp) == length(used_vars)
end
38+
2939
@testset "belief propagation" begin
    # NOTE(review): the random tensors are drawn without seeding the RNG, so the
    # convergence assertion below is not reproducible run-to-run.
    n = 5
    chi = 3
    mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi)
    bp = BeliefPropgation(mps_uai)
    @test TensorInference.initial_state(bp) isa TensorInference.BPState
    state, info = belief_propagate(bp)
    @show TensorInference.contraction_results(state)
    @test info.converged
    # printed for manual comparison with the BP result above
    tnet = TensorNetworkModel(mps_uai)
    @show probability(tnet)[]
end

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ end
2424
include("cspmodels.jl")
2525
end
2626

27+
# Tests for the helper constructors in src/utils.jl.
@testset "utils" begin
    include("utils.jl")
end
30+
31+
# Tests for the belief-propagation code in src/belief.jl.
@testset "belief propagation" begin
    include("belief.jl")
end
34+
2735
using CUDA
2836
if CUDA.functional()
2937
include("cuda.jl")

test/sampling.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ end
7171
mps = random_matrix_product_state(n, chi)
7272
num_samples = 10000
7373
ixs = OMEinsum.getixsv(mps.code)
74-
@show ixs
7574
samples = map(1:num_samples) do i
7675
sample(mps, 1; queryvars=collect(1:n)).samples[:,1]
7776
end

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using TensorInference, Test
2+
3+
@testset "tensor train" begin
    tt = random_tensor_train_uai(Float64, 5, 3)
    # every declared variable must appear in at least one factor
    used_vars = unique(reduce(vcat, [collect(Int, f.vars) for f in tt.factors]))
    @test tt.nvars == length(used_vars)
end
7+
8+
@testset "mps" begin
    mps = random_matrix_product_uai(Float64, 5, 3)
    # every declared variable must appear in at least one factor
    used_vars = unique(reduce(vcat, [collect(Int, f.vars) for f in mps.factors]))
    @test mps.nvars == length(used_vars)
end
12+

0 commit comments

Comments
 (0)