Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Hecke = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34, 0.35"
HostCPUFeatures = "0.1.6"
ILog2 = "0.2.3, 1, 2"
InteractiveUtils = "1.9"
LDPCDecoders = "0.3.2"
LDPCDecoders = "0.3.3"
LinearAlgebra = "1.9"
MacroTools = "0.5.9"
Makie = "0.20, 0.21, 0.22"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,91 @@ struct BitFlipDecoder <: AbstractSyndromeDecoder # TODO all these decoders have
bfdecoderz
end

struct BPOTSDecoder <: AbstractSyndromeDecoder
original_code
H::SparseMatrixCSC{Bool,Int}
faults_matrix::Matrix{Bool}
n::Int
s::Int
k::Int
cx::Int
cz::Int
bpots_x::LDPCDecoders.BPOTSDecoder
bpots_z::LDPCDecoders.BPOTSDecoder
end

function BPOTSDecoder(c; errorrate=nothing, maxiter=nothing, T=9, C=2.0)
# Get stabilizer matrices
Hx_raw = parity_checks_x(c)
Hz_raw = parity_checks_z(c)
H_raw = parity_checks(c)

# Convert to proper matrices
if H_raw isa Stabilizer
H_gf2 = stab_to_gf2(H_raw)
H = sparse(Bool.(H_gf2))
else
H = sparse(Bool.(H_raw))
end

# Convert X and Z matrices
if Hx_raw isa Stabilizer
Hx_gf2 = stab_to_gf2(Hx_raw)
Hz_gf2 = stab_to_gf2(Hz_raw)
Hx = sparse(Bool.(Hx_gf2))
Hz = sparse(Bool.(Hz_gf2))
else
Hx = sparse(Bool.(Hx_raw))
Hz = sparse(Bool.(Hz_raw))
end

# Get dimensions
s, n = size(H)

# For quantum codes, determine k
if c isa Toric || c isa Surface
k = c isa Toric ? 2 : 1
else
k = n - s
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a fragile piece of code, why are these special cased? Probably should use code_k code_n code_c functions.


cx = size(Hx, 1)
cz = size(Hz, 1)

# Create fault matrix
fm = BitMatrix(ones(Bool, s, 2*n))

# Create decoders
errorrate = something(errorrate, 0.0)
maxiter = something(maxiter, 200)
bpots_x = LDPCDecoders.BPOTSDecoder(Hx, errorrate, maxiter; T=T, C=C)
bpots_z = LDPCDecoders.BPOTSDecoder(Hz, errorrate, maxiter; T=T, C=C)

# Pass the original code object as the first parameter
return BPOTSDecoder(c, H, fm, n, s, k, cx, cz, bpots_x, bpots_z)
end

function decode(d::BPOTSDecoder, syndrome_sample::AbstractVector{Bool})
# Validate input size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove comments that are just a obvious statement, comments should provide insight about the code, not state things that are already obvious

length(syndrome_sample) == d.cx + d.cz ||
throw(DimensionMismatch("Syndrome length ($(length(syndrome_sample))) does not match expected size ($(d.cx + d.cz))"))

# Split syndrome
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]

# Decode both parts
guess_z, conv_z = LDPCDecoders.decode!(d.bpots_x, Vector(row_x))
guess_x, conv_x = LDPCDecoders.decode!(d.bpots_z, Vector(row_z))

# Return combined X and Z errors
return vcat(guess_x, guess_z)
end

function parity_checks(d::BPOTSDecoder)
return d.H
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeated

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this repetition


function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing)
Hx = parity_checks_x(c)
Hz = parity_checks_z(c)
Expand Down Expand Up @@ -64,7 +149,6 @@ function BitFlipDecoder(c; errorrate=nothing, maxiter=nothing)
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"BitFlipDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].")
errorrate = isnothing(errorrate) ? 0.0 : errorrate
maxiter = isnothing(maxiter) ? n : maxiter
bfx = LDPCDecoders.BitFlipDecoder(Hx, errorrate, maxiter)
bfz = LDPCDecoders.BitFlipDecoder(Hz, errorrate, maxiter)

return BitFlipDecoder(H, fm, n, s, k, cx, cz, bfx, bfz)
Expand Down
Loading