Skip to content

Commit ed3e65d

Browse files
committed
initial MP decoder tests
1 parent 63edd8f commit ed3e65d

File tree

3 files changed

+136
-84
lines changed

3 files changed

+136
-84
lines changed

src/LDPC/MP_decoders.jl

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ function _SP_check_node_message(c::Int, v::Int, iter::Int, check_adj_list::Vecto
9898
end
9999

100100
function ϕ_test(x::Real)
101-
# x >= 0 ? (return -log(tanh(0.5 * x));) : (return log(tanh(-0.5 * x));)
102-
x >= 0 ? (return log((exp(x) + 1)/(exp(x) - 1));) : (return -log((exp(-x) + 1)/(exp(-x) - 1));)
101+
# TODO why does the other one produce NaN
102+
x >= 0 ? (return -log(tanh(0.5 * x));) : (return log(tanh(-0.5 * x));)
103+
# x >= 0 ? (return log((exp(x) + 1)/(exp(x) - 1));) : (return -log((exp(-x) + 1)/(exp(-x) - 1));)
103104
end
104105

105106
(a::Float64, b::Float64) = log((1 + exp(a + b)) / (exp(a) + exp(b)))
@@ -539,15 +540,15 @@ end
539540
function _channel_init_BSC(v::Vector{<: Integer}, p::Float64)
540541
temp = log((1 - p) / p)
541542
chn_init = zeros(Float64, length(v))
542-
for i in 1:ncols(v)
543+
for i in 1:length(v)
543544
@inbounds chn_init[i] = (-1)^v[i] * temp
544545
end
545546
return chn_init
546547
end
547548

548549
function _channel_init_BSC!(var_to_check_messages::Matrix{Float64}, v::CTMatrixTypes, p::Float64)
549550
temp = log((1 - p) / p)
550-
@inbounds for i in 1:ncols(v)
551+
@inbounds for i in 1:nrows(v)
551552
iszero(v[i]) ? (var_to_check_messages[i, 1] = temp;) : (var_to_check_messages[i, 1] = -temp;)
552553
end
553554
return nothing
@@ -564,15 +565,15 @@ end
564565
function _channel_init_BAWGNC_SP(v::Vector{<: AbstractFloat}, σ::Float64)
565566
temp = 2 / σ^2
566567
chn_init = zeros(Float64, length(v))
567-
for i in 1:ncols(v)
568+
for i in 1:nrows(v)
568569
@inbounds chn_init[i] = temp * v[i]
569570
end
570571
return chn_init
571572
end
572573

573574
function _channel_init_BAWGNC_SP!(var_to_check_messages::Matrix{Float64}, v::Vector{<: AbstractFloat}, σ::Float64)
574575
temp = 2 / σ^2
575-
@inbounds for i in 1:ncols(v)
576+
@inbounds for i in 1:nrows(v)
576577
var_to_check_messages[i, 1] = temp * v[i]
577578
end
578579
return nothing
@@ -652,7 +653,7 @@ function _message_passing_init_fast(H::Union{Matrix{S}, T}, v::Union{Vector{S},
652653
all(1 bit num_var for bit in erasures) ||
653654
throw(ArgumentError("Invalid bit index in erasures"))
654655
@inbounds for i in erasures
655-
var_to_check_messages[i, 1] = 0.0
656+
var_to_check_messages[i, 1] = 1e-10
656657
end
657658
end
658659

@@ -679,7 +680,7 @@ function _message_passing_init(H::Union{Matrix{S}, T}, v::Union{Vector{S}, Vecto
679680
end
680681
num_check, num_var = size(H_Int)
681682
num_check > 0 && num_var > 0 || throw(ArgumentError("Input matrix of improper dimension"))
682-
683+
683684
len_v = length(v)
684685
if len_v == num_var
685686
syndrome_based = false
@@ -692,7 +693,7 @@ function _message_passing_init(H::Union{Matrix{S}, T}, v::Union{Vector{S}, Vecto
692693
else
693694
throw(ArgumentError("Vector has incorrect dimension"))
694695
end
695-
696+
696697
check_adj_list = [Int[] for _ in 1:num_check]
697698
var_adj_list = [Int[] for _ in 1:num_var]
698699
for r in 1:num_check
@@ -725,7 +726,7 @@ function _message_passing_init(H::Union{Matrix{S}, T}, v::Union{Vector{S}, Vecto
725726
elseif syndrome_based && ismissing(chn_inits) && isa(chn, BinarySymmetricChannel)
726727
temp = log((1 - chn.param) / chn.param)
727728
# var_to_check_messages[:, 1] .= temp
728-
chn_inits_2 = [iszero(v[i]) ? temp : -temp for i in 1:num_vars]
729+
chn_inits_2 = [temp for _ in 1:num_var]
729730
elseif !ismissing(chn_inits)
730731
length(chn_inits) num_var && throw(ArgumentError("Channel inputs has wrong size"))
731732
# var_to_check_messages[:, 1] .= chn_inits
@@ -743,7 +744,7 @@ function _message_passing_init(H::Union{Matrix{S}, T}, v::Union{Vector{S}, Vecto
743744
all(1 bit num_var for bit in erasures) ||
744745
throw(ArgumentError("Invalid bit index in erasures"))
745746
@inbounds for i in erasures
746-
chn_inits_2[i] = 0.0
747+
chn_inits_2[i] = 1e-10
747748
end
748749
end
749750

@@ -813,7 +814,7 @@ function _message_passing_init_Int(H::Union{Matrix{S}, T}, v::Union{Vector{S},
813814
# all(1 ≤ bit ≤ num_var for bit in erasures) ||
814815
# throw(ArgumentError("Invalid bit index in erasures"))
815816
# @inbounds for i in erasures
816-
# chn_inits[i] = 0.0
817+
# chn_inits[i] = 1e-10
817818
# end
818819
# end
819820

@@ -916,7 +917,7 @@ function _message_passing_init_decimation(H::T, v::T, chn::AbstractClassicalNois
916917
all(1 bit num_var for bit in erasures) ||
917918
throw(ArgumentError("Invalid bit index in erasures"))
918919
@inbounds for i in erasures
919-
chn_inits[i] = 0.0
920+
chn_inits[i] = 1e-10
920921
end
921922
end
922923

@@ -967,9 +968,9 @@ function _message_passing(H::Matrix{T}, syndrome::Union{Missing, Vector{T}},
967968
end
968969
LinearAlgebra.mul!(syn, H, current_bits)
969970
if !ismissing(syndrome)
970-
all(syn[i] .% 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, 1
971+
all(syn[i] .% 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, 1, totals
971972
else
972-
all(iszero(syn[i] .% 2) for i in 1:num_check) && return true, current_bits, 1
973+
all(iszero(syn[i] .% 2) for i in 1:num_check) && return true, current_bits, 1, totals
973974
end
974975

975976
iter = 2
@@ -1009,9 +1010,9 @@ function _message_passing(H::Matrix{T}, syndrome::Union{Missing, Vector{T}},
10091010

10101011
LinearAlgebra.mul!(syn, H, current_bits)
10111012
if !ismissing(syndrome)
1012-
all(syn[i] .% 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, iter
1013+
all(syn[i] .% 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, iter, totals
10131014
else
1014-
all(iszero(syn[i] .% 2) for i in 1:num_check) && return true, current_bits, iter
1015+
all(iszero(syn[i] .% 2) for i in 1:num_check) && return true, current_bits, iter, totals
10151016
end
10161017

10171018
if schedule == :parallel
@@ -1091,9 +1092,9 @@ function _message_passing_layered(H::Matrix{T}, syndrome::Union{Missing, Vector{
10911092
# iteration done, check if converged
10921093
LinearAlgebra.mul!(syn, H, current_bits)
10931094
if !ismissing(syndrome)
1094-
all(syn[i] % 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, iter
1095+
all(syn[i] % 2 == syndrome[i] for i in 1:num_check) && return true, current_bits, iter, totals
10951096
else
1096-
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter
1097+
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter, totals
10971098
end
10981099

10991100
iter += 1
@@ -1126,7 +1127,7 @@ function _message_passing_fast(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, syndrome_
11261127
# TODO fix phi_inv here to not need this (-1)^sign term
11271128
temp = S - phi(Q_temp)
11281129
# BUG? seems to converge if I put the minus sign before (-1) here?!?!?!
1129-
check_to_var_messages[c][i] = -(-1)^sign(temp) * phi_inv(temp)
1130+
check_to_var_messages[c][i] = (-1)^sign(temp) * phi_inv(temp)
11301131
var_to_check_messages[v, curr_iter] = Q_temp + check_to_var_messages[c][i]
11311132
end
11321133
end
@@ -1138,9 +1139,9 @@ function _message_passing_fast(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, syndrome_
11381139

11391140
LinearAlgebra.mul!(syn, H_Int, current_bits)
11401141
if syndrome_based
1141-
all(syn[i] % 2 == v[i, 1] for i in 1:num_check) && return true, current_bits, iter
1142+
all(syn[i] % 2 == v[i, 1] for i in 1:num_check) && return true, current_bits, iter, var_to_check_messages[:, curr_iter]
11421143
else
1143-
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter
1144+
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter, var_to_check_messages[:, curr_iter]
11441145
end
11451146

11461147
if schedule == :parallel
@@ -1155,7 +1156,7 @@ function _message_passing_fast(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, syndrome_
11551156
# for i in 1:num_var
11561157
current_bits[i] = var_to_check_messages[i, curr_iter] >= 0 ? 0 : 1
11571158
end
1158-
return false, current_bits, iter, totals
1159+
return false, current_bits, iter, var_to_check_messages[:, curr_iter]
11591160
end
11601161

11611162
function _message_passing_fast_layered(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, syndrome_based::Bool,
@@ -1201,9 +1202,9 @@ function _message_passing_fast_layered(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, s
12011202

12021203
LinearAlgebra.mul!(syn, H_Int, current_bits)
12031204
if syndrome_based
1204-
all(syn[i] % 2 == v[i, 1] for i in 1:num_check) && return true, current_bits, iter
1205+
all(syn[i] % 2 == v[i, 1] for i in 1:num_check) && return true, current_bits, iter, var_to_check_messages[:, curr_iter]
12051206
else
1206-
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter
1207+
all(iszero(syn[i] % 2) for i in 1:num_check) && return true, current_bits, iter, var_to_check_messages[:, curr_iter]
12071208
end
12081209

12091210
if schedule == :parallel
@@ -1217,7 +1218,7 @@ function _message_passing_fast_layered(H_Int::Matrix{UInt8}, v::Matrix{UInt8}, s
12171218
@inbounds for i in 1:num_var
12181219
current_bits[i] = var_to_check_messages[i, curr_iter] >= 0 ? 0 : 1
12191220
end
1220-
return false, current_bits, iter, totals
1221+
return false, current_bits, iter, var_to_check_messages[:, curr_iter]
12211222
end
12221223

12231224
# significant speedups seperating the float and int code
@@ -1322,7 +1323,7 @@ function _message_passing_Int(H::Matrix{T}, syndrome::Union{Missing, Vector{T}},
13221323
iter += 1
13231324
end
13241325

1325-
return false, current_bits, iter, totals
1326+
return false, current_bits, iter
13261327
end
13271328

13281329
function _message_passing_decimation(H::Matrix{T}, w::Vector{T}, chn_inits::Union{Missing,
@@ -1392,8 +1393,7 @@ function _message_passing_decimation(H::Matrix{T}, w::Vector{T}, chn_inits::Unio
13921393
end
13931394

13941395
LinearAlgebra.mul!(syn, H, current_bits)
1395-
iszero(syn .% 2) && return true, current_bits, iter, var_to_check_messages,
1396-
check_to_var_messages
1396+
iszero(syn .% 2) && return true, current_bits, iter, totals
13971397

13981398
if algorithm == :guided && iszero(iter % guided_rounds)
13991399
val, index = findmax(totals)
@@ -1533,23 +1533,3 @@ function balance_of_layered_schedule(sch::Vector{Vector{Int}})
15331533
end
15341534
return γ
15351535
end
1536-
1537-
# output should end up [1 1 1 0 0 0 0]
1538-
1539-
# H = matrix(GF(2), [1 1 0 1 1 0 0; 1 0 1 1 0 1 0; 0 1 1 1 0 0 1]);
1540-
# v = matrix(GF(2), 7, 1, [1, 1, 0, 0, 0, 0, 0]);
1541-
# syn = H * v;
1542-
# nm = AbstractClassicalNoiseChannel(:BSC, 1/7);
1543-
# decimated_bits_values = [(1, base_ring(v)(1))];
1544-
# flag, out, iter = sum_product(H, v, nm); flag
1545-
# flag, out, iter = sum_product_syndrome(H, syn, nm); flag
1546-
# flag, out, iter = CodingTheory.sum_product_decimation(H, v, nm, decimated_bits_values); flag
1547-
1548-
# TODO: all of the min-sum versions fail
1549-
# flag, out, iter = min_sum(H, v, nm)
1550-
# flag, out, iter = min_sum_syndrome(H, syn, nm); flag
1551-
# flag, out, iter = CodingTheory.min_sum_decimation(H, v, nm, decimated_bits_values); flag
1552-
1553-
# flag, out, iter = Gallager_A(H, v); flag
1554-
# flag, out, iter = Gallager_B(H, v); flag
1555-
# TODO: B fails

src/Quantum/decoders/OTF.jl

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ function ordered_Tanner_forest(H::T, v::T, chn::AbstractClassicalNoiseChannel, B
88
max_iter::Int = 100, chn_inits::Union{Missing, Vector{Float64}} = missing, schedule::Symbol =
99
:parallel, rand_sched::Bool = false, erasures::Vector{Int} = Int[]) where T <: CTMatrixTypes
1010

11+
# TODO add MS_C
12+
# TODO decimation
1113
BP_alg (:SP, :MS) || throw(ArgumentError("`BP_alg` should be `:SP` or `:MS`"))
1214
Int(order(base_ring(H))) == 2 || throw(ArgumentError("Currently only implemented for binary codes"))
1315
nr, nc = size(H)
@@ -18,20 +20,13 @@ function ordered_Tanner_forest(H::T, v::T, chn::AbstractClassicalNoiseChannel, B
1820
schedule == :flooding && (schedule = :parallel;)
1921
schedule == :semiserial && (schedule = :layered;)
2022

21-
H_Int, v_Int, syndrome_based, check_adj_list, check_to_var_messages, var_to_check_messages,
22-
current_bits, syn = _message_passing_init_fast(H, v, chn, :BP_alg, chn_inits, :serial,
23-
erasures)
24-
25-
# 1. Run BP
26-
# 2. Sort soft info from output
27-
# 3. Make hypergraph (implicit step)
28-
# 4. start with least reliable node and record edges to neighbors
29-
# 5. Move to next least reliable node, store same making sure no loop is added
30-
# 6. Repeat until done
31-
# 7. For hyperedges that are never added, set vertex to erased in original Tanner graph
32-
# 8. Run BP
23+
# TODO search for wt 1 columns and add new row
3324

3425
if BP_alg == :SP
26+
H_Int, v_Int, syndrome_based, check_adj_list, check_to_var_messages, var_to_check_messages,
27+
current_bits, syn = _message_passing_init_fast(H, v, chn, :BP_alg, chn_inits, :serial,
28+
erasures)
29+
3530
if schedule == :layered
3631
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
3732
flag, e, _, posteriors = _message_passing_fast_layered(H_Int, v_Int, syndrome_based,
@@ -43,28 +38,29 @@ function ordered_Tanner_forest(H::T, v::T, chn::AbstractClassicalNoiseChannel, B
4338
ϕ_test, ϕ_test, max_iter)
4439
end
4540
elseif BP_alg == :MS
46-
# TODO _MS_check_node_message
47-
if schedule == :layered
48-
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
49-
flag, e, _, posteriors = _message_passing_fast_layered(H_Int, v_Int, syndrome_based,
50-
check_adj_list, check_to_var_messages, var_to_check_messages, current_bits, syn,
51-
ϕ_test, ϕ_test, max_iter, layers)
52-
else
53-
flag, e, _, posteriors = _message_passing_fast(H_Int, v_Int, syndrome_based,
54-
check_adj_list, check_to_var_messages, var_to_check_messages, current_bits, syn,
55-
ϕ_test, ϕ_test, max_iter)
56-
end
41+
H_Int, _, var_adj_list, check_adj_list, chn_inits_2, check_to_var_messages,
42+
var_to_check_messages, current_bits, totals, syn = _message_passing_init(H, v, chn,
43+
:MS, chn_inits, schedule, erasures)
44+
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
45+
flag, e, _, posteriors = _message_passing_layered(H_Int, missing, chn_inits_2,
46+
_MS_check_node_message, var_adj_list, check_adj_list, max_iter, schedule,
47+
current_bits, totals, syn, check_to_var_messages, var_to_check_messages,
48+
attenuation, layers)
5749
end
5850

5951
if !flag
6052
# initial BP did not converge
6153
ordered_indices = sortperm(posteriors, rev = true)
6254
erased_columns = _select_erased_columns(H, posteriors, ordered_indices)
6355
for i in erased_columns
64-
posteriors[i] = 0.0
56+
posteriors[i] = 1e-10
6557
end
6658

6759
if BP_alg == :SP
60+
H_Int, v_Int, syndrome_based, check_adj_list, check_to_var_messages, var_to_check_messages,
61+
current_bits, syn = _message_passing_init_fast(H, v, chn, :BP_alg, chn_inits, :serial,
62+
erasures)
63+
6864
if schedule == :layered
6965
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
7066
flag, e, _, posteriors = _message_passing_fast_layered(H_Int, v_Int, syndrome_based,
@@ -76,17 +72,14 @@ function ordered_Tanner_forest(H::T, v::T, chn::AbstractClassicalNoiseChannel, B
7672
ϕ_test, ϕ_test, max_iter)
7773
end
7874
elseif BP_alg == :MS
79-
# TODO _MS_check_node_message
80-
if schedule == :layered
81-
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
82-
flag, e, _, posteriors = _message_passing_fast_layered(H_Int, v_Int, syndrome_based,
83-
check_adj_list, check_to_var_messages, var_to_check_messages, current_bits, syn,
84-
ϕ_test, ϕ_test, max_iter, layers)
85-
else
86-
flag, e, _, posteriors = _message_passing_fast(H_Int, v_Int, syndrome_based,
87-
check_adj_list, check_to_var_messages, var_to_check_messages, current_bits, syn,
88-
ϕ_test, ϕ_test, max_iter)
89-
end
75+
H_Int, _, var_adj_list, check_adj_list, chn_inits_2, check_to_var_messages,
76+
var_to_check_messages, current_bits, totals, syn = _message_passing_init(H, v, chn,
77+
:MS, chn_inits, schedule, erasures)
78+
layers = layered_schedule(H, schedule = schedule, random = rand_sched)
79+
flag, e, _, posteriors = _message_passing_layered(H_Int, missing, chn_inits_2,
80+
_MS_check_node_message, var_adj_list, check_adj_list, max_iter, schedule,
81+
current_bits, totals, syn, check_to_var_messages, var_to_check_messages,
82+
attenuation, layers)
9083
end
9184
end
9285

0 commit comments

Comments
 (0)