|
1 | 1 | struct BPState{T, VT<:AbstractVector{T}} |
2 | 2 | t2v::Vector{Vector{Int}} # a mapping from tensors to variables |
3 | 3 | v2t::Vector{Vector{Int}} # a mapping from variables to tensors |
4 | | - edges_vectors::Vector{Vector{VT}} # each tensor is associated with a vector of vectors, one for each neighbor |
| 4 | + tensors::Vector{AbstractArray{T}} # the tensors |
| 5 | + message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages |
| 6 | + message_out::Vector{Vector{VT}} # the outgoing messages |
| 7 | +end |
| 8 | + |
| 9 | +# message_in -> message_out |
| 10 | +function process_message!(bp::BPState) |
| 11 | + for (ov, iv) in zip(bp.message_out, bp.message_in) |
| 12 | + _process_message!(ov, iv) |
| 13 | + end |
| 14 | +end |
| 15 | +function _process_message!(ov::Vector, iv::Vector) |
| 16 | + # process the message, TODO: speed up if needed! |
| 17 | + for (i, v) in enumerate(ov) |
| 18 | + fill!(v, one(eltype(v))) # clear the output vector |
| 19 | + for (j, u) in enumerate(iv) |
| 20 | + j != i && (v .*= u) |
| 21 | + end |
| 22 | + end |
| 23 | +end |
| 24 | + |
| 25 | +function collect_message!(bp::BPState) |
| 26 | + for (it, t) in enumerate(bp.t2v) |
| 27 | + _collect_message!(vectors_on_tensor(bp.message_out, bp, it), t, vectors_on_tensor(bp.message_in, bp, it)) |
| 28 | + end |
| 29 | +end |
| 30 | +# collect the vectors associated with the target tensor |
| 31 | +function vectors_on_tensor(messages, bp::BPState, it::Int) |
| 32 | + return map(bp.t2v[it]) do v |
| 33 | + # the message goes to the idx-th tensor from variable v |
| 34 | + messages[v][findfirst(==(it), bp.v2t[v])] |
| 35 | + end |
| 36 | +end |
| 37 | +function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector) |
| 38 | + @assert length(vectors_out) == length(vectors_in) == ndims(t) |
| 39 | + # TODO: speed up if needed! |
| 40 | + code = star_code(length(vectors_in)) |
| 41 | + cost, gradient = cost_and_gradient(code, [t, vectors_in...]) |
| 42 | + for (o, g) in zip(vectors_out, gradient[2:end]) |
| 43 | + o .= g |
| 44 | + end |
| 45 | + return cost |
| 46 | +end |
| 47 | +function star_code(n::Int) |
| 48 | + ix1, ixrest = collect(1:n), [[i] for i in 1:n] |
| 49 | + ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n))) |
| 50 | + for i in 2:n |
| 51 | + ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect(i+1:n))) |
| 52 | + end |
| 53 | + return ne |
5 | 54 | end |
6 | 55 |
|
7 | 56 | function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T |
@@ -39,9 +88,6 @@ function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_ |
39 | 88 | end |
40 | 89 | end |
41 | 90 |
|
42 | | -function tensor_product() |
43 | | -end |
44 | | - |
45 | 91 | function belief_propagation(tn::TensorNetworkModel{T}) where T |
46 | 92 | return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict)) |
47 | 93 | end |
0 commit comments