Skip to content

Commit 36e047e

Browse files
committed
update
1 parent 6f96133 commit 36e047e

File tree

3 files changed

+89
-4
lines changed

3 files changed

+89
-4
lines changed

src/TensorInference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ export MMAPModel
4040
# for ProblemReductions
4141
export update_temperature
4242

43+
# belief propagation
44+
export belief_propagation
45+
4346
# utils
4447
export random_matrix_product_state
4548

src/belief.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
11
struct BPState{T, VT<:AbstractVector{T}}
22
t2v::Vector{Vector{Int}} # a mapping from tensors to variables
33
v2t::Vector{Vector{Int}} # a mapping from variables to tensors
4-
edges_vectors::Vector{Vector{VT}} # each tensor is associated with a vector of vectors, one for each neighbor
4+
tensors::Vector{AbstractArray{T}} # the tensors
5+
message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages
6+
message_out::Vector{Vector{VT}} # the outgoing messages
7+
end
8+
9+
# message_in -> message_out
10+
function process_message!(bp::BPState)
11+
for (ov, iv) in zip(bp.message_out, bp.message_in)
12+
_process_message!(ov, iv)
13+
end
14+
end
15+
function _process_message!(ov::Vector, iv::Vector)
16+
# process the message, TODO: speed up if needed!
17+
for (i, v) in enumerate(ov)
18+
fill!(v, one(eltype(v))) # clear the output vector
19+
for (j, u) in enumerate(iv)
20+
j != i && (v .*= u)
21+
end
22+
end
23+
end
24+
25+
function collect_message!(bp::BPState)
26+
for (it, t) in enumerate(bp.t2v)
27+
_collect_message!(vectors_on_tensor(bp.message_out, bp, it), t, vectors_on_tensor(bp.message_in, bp, it))
28+
end
29+
end
30+
# collect the vectors associated with the target tensor
31+
function vectors_on_tensor(messages, bp::BPState, it::Int)
32+
return map(bp.t2v[it]) do v
33+
# the message goes to the idx-th tensor from variable v
34+
messages[v][findfirst(==(it), bp.v2t[v])]
35+
end
36+
end
37+
function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector)
38+
@assert length(vectors_out) == length(vectors_in) == ndims(t)
39+
# TODO: speed up if needed!
40+
code = star_code(length(vectors_in))
41+
cost, gradient = cost_and_gradient(code, [t, vectors_in...])
42+
for (o, g) in zip(vectors_out, gradient[2:end])
43+
o .= g
44+
end
45+
return cost
46+
end
47+
function star_code(n::Int)
48+
ix1, ixrest = collect(1:n), [[i] for i in 1:n]
49+
ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n)))
50+
for i in 2:n
51+
ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect(i+1:n)))
52+
end
53+
return ne
554
end
655

756
function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T
@@ -39,9 +88,6 @@ function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_
3988
end
4089
end
4190

42-
function tensor_product()
43-
end
44-
4591
function belief_propagation(tn::TensorNetworkModel{T}) where T
4692
return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict))
4793
end

test/belief.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using TensorInference, Test
2+
3+
@testset "process message" begin
4+
mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
5+
mo_expected = [[6, 12, 20], [3, 8, 15], [2, 6, 12]]
6+
mo = similar.(mi)
7+
TensorInference._process_message!(mo, mi)
8+
@test mo == mo_expected
9+
end
10+
11+
@testset "star code" begin
12+
code = TensorInference.star_code(3)
13+
c1, c2, c3, c4 = [DynamicNestedEinsum{Int}(i) for i in 1:4]
14+
ne1 = DynamicNestedEinsum([c1, c2], DynamicEinCode([[1, 2, 3], [1]], [2, 3]))
15+
ne2 = DynamicNestedEinsum([ne1, c3], DynamicEinCode([[2, 3], [2]], [3]))
16+
ne3 = DynamicNestedEinsum([ne2, c4], DynamicEinCode([[3], [3]], Int[]))
17+
@test code == ne3
18+
t = randn(2, 2, 2)
19+
v1 = randn(2)
20+
v2 = randn(2)
21+
v3 = randn(2)
22+
vectors_out = [similar(v1), similar(v2), similar(v3)]
23+
TensorInference._collect_message!(vectors_out, t, [v1, v2, v3])
24+
@test vectors_out[1] reshape(t, 2, 4) * kron(v3, v2) # NOTE: v3 is the little end
25+
@test vectors_out[2] vec(v1' * reshape(reshape(t, 4, 2) * v3, 2, 2))
26+
@test vectors_out[3] vec(kron(v2, v1)' * reshape(t, 4, 2))
27+
end
28+
29+
@testset "belief propagation" begin
30+
n = 5
31+
chi = 3
32+
Random.seed!(140)
33+
mps = random_matrix_product_state(n, chi)
34+
model = TensorNetworkModel(mps)
35+
state = belief_propagation(model)
36+
end

0 commit comments

Comments
 (0)