Skip to content

Commit 4309e96

Browse files
committed
format document and fix tests
1 parent e1a3945 commit 4309e96

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/belief.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
num_tensors(bp::BeliefPropgation) = length(bp.t2v)
1818
ProblemReductions.num_variables(bp::BeliefPropgation) = length(bp.v2t)
1919

20-
function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T
20+
function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where {T}
2121
# initialize the inverse mapping
2222
v2t = [Int[] for _ in 1:nvars]
2323
for (i, edge) in enumerate(t2v)
@@ -33,11 +33,11 @@ $(TYPEDSIGNATURES)
3333
3434
Construct a belief propagation object from a [`UAIModel`](@ref).
3535
"""
36-
function BeliefPropgation(uai::UAIModel{T}) where T
36+
function BeliefPropgation(uai::UAIModel{T}) where {T}
3737
return BeliefPropgation(uai.nvars, [collect(Int, f.vars) for f in uai.factors], AbstractArray{T}[f.vals for f in uai.factors])
3838
end
3939

40-
struct BPState{T, VT<:AbstractVector{T}}
40+
struct BPState{T, VT <: AbstractVector{T}}
4141
message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages
4242
message_out::Vector{Vector{VT}} # the outgoing messages
4343
end
@@ -91,12 +91,12 @@ function star_code(n::Int)
9191
ix1, ixrest = collect(1:n), [[i] for i in 1:n]
9292
ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n)))
9393
for i in 2:n
94-
ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect(i+1:n)))
94+
ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect((i + 1):n)))
9595
end
9696
return ne
9797
end
9898

99-
function initial_state(bp::BeliefPropgation{T}) where T
99+
function initial_state(bp::BeliefPropgation{T}) where {T}
100100
size_dict = OMEinsum.get_size_dict(bp.t2v, bp.tensors)
101101
edges_vectors = Vector{Vector{T}}[]
102102
for (i, tids) in enumerate(bp.v2t)
@@ -126,13 +126,13 @@ struct BPInfo
126126
converged::Bool
127127
iterations::Int
128128
end
129-
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol=1e-6, damping=0.2) where T
129+
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int = 100, tol = 1e-6, damping = 0.2) where {T}
130130
pre_message_in = deepcopy(state.message_in)
131131
for i in 1:max_iter
132-
collect_message!(bp, state; normalize=true)
133-
process_message!(state; normalize=true, damping=damping)
132+
collect_message!(bp, state; normalize = true)
133+
process_message!(state; normalize = true, damping = damping)
134134
# check convergence
135-
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
135+
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
136136
return BPInfo(true, i)
137137
end
138138
pre_message_in = deepcopy(state.message_in)
@@ -141,13 +141,13 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In
141141
end
142142

143143
# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction
144-
function contraction_results(state::BPState{T}) where T
144+
function contraction_results(state::BPState{T}) where {T}
145145
return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in]
146146
end
147147

148148
"""
149149
$(TYPEDSIGNATURES)
150150
"""
151-
function marginals(state::BPState{T}) where T
151+
function marginals(state::BPState{T}) where {T}
152152
return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in))
153153
end

test/belief.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ using OMEinsum, LinearAlgebra
66
mo_expected = [[6.0, 12, 20], [3.0, 8, 15], [2.0, 6, 12]]
77
mo = similar.(mi)
88
TensorInference._process_message!(mo, mi, false, 0)
9-
@test all(mo .≈ mo_expected)
9+
for i in 1:length(mo)
10+
@test mo[i] mo_expected[i] atol=1e-8
11+
end
1012

1113
TensorInference._process_message!(mo, mi, true, 0)
12-
@test all(mo .≈ normalize!.(mo_expected, 1))
14+
for i in 1:length(mo)
15+
@test mo[i] normalize!(mo_expected[i], 1) atol=1e-8
16+
end
1317
end
1418

1519
@testset "star code" begin

0 commit comments

Comments
 (0)