Skip to content

Commit d623cd8

Browse files
esaboFe-r-oz
authored andcommitted
bug fixes for RIS
1 parent 9769b43 commit d623cd8

File tree

3 files changed

+61
-46
lines changed

3 files changed

+61
-46
lines changed

src/CodingTheory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ include("Quantum/weight_dist.jl")
362362
# export weight_plot_CSS_X, weight_plot_CSS_Z, weight_plot_CSS, minimum_distance_X_Z,
363363
# minimum_distance_X, minimum_distance_Z, is_pure, QDistRndCSS
364364
export minimum_distance_upper_bound!, random_information_set_minimum_distance_bound!,
365-
QDistRand
365+
QDistRand!
366366

367367
#############################
368368
# Quantum/product_codes.jl

src/Quantum/weight_dist.jl

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -637,19 +637,22 @@ is returned, and if `dressed = false` then the bound is for the bare distance. R
637637
iterations and stops early if a logical of weight `d_lower_bound` is found.
638638
"""
639639
# here want to branch for graph states
640-
function random_information_set_minimum_distance_bound(S::T, which::Symbol = :full;
640+
function random_information_set_minimum_distance_bound!(S::T, which::Symbol = :full;
641641
dressed::Bool = true, max_iters::Int = 10000, verbose::Bool = false) where T <: AbstractSubsystemCode
642642

643643
which (:full, :X, :Z) || throw(DomainError(which, "Must choose `:full`, `:X` or `:Z`."))
644644
# order(field(S)) == 2 || throw(DomainError(S, "Currently only implemented for binary codes."))
645645
is_positive(max_iters) || throw(DomainError(max_iters, "The number of iterations must be a positive integer."))
646646

647-
return random_information_set_minimum_distance_bound(GaugeTrait(T), CSSTrait(T),
647+
return random_information_set_minimum_distance_bound!(GaugeTrait(T), CSSTrait(T),
648648
LogicalTrait(T), S, which, dressed, max_iters, verbose)
649649
end
650-
QDistRand(S::T, which::Symbol = :full; dressed::Bool = true, max_iters::Int = 10000, verbose::Bool = false) where T <: AbstractSubsystemCode = random_information_set_minimum_distance_bound(S, which; dressed = dressed, max_iters = max_iters, verbose = verbose)
650+
QDistRand!(S::T, which::Symbol = :full; dressed::Bool = true, max_iters::Int = 10000,
651+
verbose::Bool = false) where T <: AbstractSubsystemCode =
652+
random_information_set_minimum_distance_bound!(S, which; dressed = dressed, max_iters =
653+
max_iters, verbose = verbose)
651654

652-
function random_information_set_minimum_distance_bound(::HasGauges, ::IsCSS, ::HasLogicals,
655+
function random_information_set_minimum_distance_bound!(::HasGauges, ::IsCSS, ::HasLogicals,
653656
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int, verbose::Bool)
654657
# this is a CSS subsystem code
655658

@@ -752,7 +755,7 @@ function random_information_set_minimum_distance_bound(::HasGauges, ::IsCSS, ::H
752755
end
753756
end
754757

755-
uppers, founds = _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound,
758+
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
756759
curr_u_bound, found, max_iterst, n, verbose)
757760
loc = argmin(uppers)
758761
if dressed
@@ -775,33 +778,38 @@ function random_information_set_minimum_distance_bound(::HasGauges, ::IsCSS, ::H
775778
return uppers[loc], founds[loc]
776779
end
777780

778-
function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsCSS, ::HasLogicals,
779-
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int)
781+
function random_information_set_minimum_distance_bound!(::HasNoGauges, ::IsCSS, ::HasLogicals,
782+
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int, verbose::Bool)
780783
# CSS stabilizer code
781784

785+
n = S.n
782786
if which == :full
783787
verbose && println("Bounding the full distance")
784-
stabs = _remove_empty(_rref_no_col_swap!(_Flint_matrix_to_Julia_T_matrix(
785-
stabilizers(S), UInt8)), :rows)
786-
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), UInt8)
788+
# println(_rref_no_col_swap!(_Flint_matrix_to_Julia_T_matrix(stabilizers(S), Bool)))
789+
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
790+
_rref_no_col_swap!(stabs)
791+
stabs = _remove_empty(stabs, :rows)
792+
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Int)
787793
operators_to_reduce = vcat(stabs, logs)
788794
check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
789-
curr_l_bound = S.l_bound_bare
795+
curr_l_bound = S.l_bound
796+
verbose && println("Starting lower bound: $curr_l_bound")
790797

791798
# this is done in the constructor but the logical is not stored at the time
792799
# so must redo here
793-
_, mat = _rref_no_col_swap!(operators_to_reduce)
800+
mat = _rref_no_col_swap(operators_to_reduce)
794801
anti = mat * check_against
795-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i:i, :]) for i in axes(anti, 1)), :]))
802+
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
796803
found = operators_to_reduce[index, :]
804+
verbose && println("Starting upper bound: $curr_u_bound")
797805
elseif which == :X
798806
verbose && println("Bounding the X-distance")
799807
stabs = _remove_empty(_rref_no_col_swap!(_Flint_matrix_to_Julia_T_matrix(
800808
X_stabilizers(S), UInt8)), :rows)
801809
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]), UInt8)
802810
operators_to_reduce = vcat(stabs, logs)
803811
check_against = permutedims(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals])[:, [n + 1:2n; 1:n]])
804-
curr_l_bound = S.l_bound_dx_bare
812+
curr_l_bound = S.l_bound_dx
805813

806814
# this is done in the constructor but the logical is not stored at the time
807815
# so must redo here
@@ -816,7 +824,7 @@ function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsCSS, :
816824
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals]), UInt8)
817825
operators_to_reduce = vcat(stabs, logs)
818826
check_against = permutedims(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]))
819-
curr_l_bound = S.l_bound_dx_bare
827+
curr_l_bound = S.l_bound_dx
820828

821829
# this is done in the constructor but the logical is not stored at the time
822830
# so must redo here
@@ -826,8 +834,8 @@ function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsCSS, :
826834
found = operators_to_reduce[index, :]
827835
end
828836

829-
uppers, founds = _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound,
830-
curr_u_bound, found, max_iterst, n, verbose)
837+
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
838+
curr_u_bound, found, max_iters, n, verbose)
831839
loc = argmin(uppers)
832840
if which == :full
833841
S.u_bound = uppers[loc]
@@ -839,8 +847,8 @@ function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsCSS, :
839847
return uppers[loc], founds[loc]
840848
end
841849

842-
function random_information_set_minimum_distance_bound(::HasGauges, ::IsNotCSS, ::HasLogicals,
843-
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int)
850+
function random_information_set_minimum_distance_bound!(::HasGauges, ::IsNotCSS, ::HasLogicals,
851+
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int, verbose::Bool)
844852
# subsystem code
845853

846854
which == :full || throw(ArguementError(which, "Parameter is not valid for non-CSS codes."))
@@ -859,7 +867,7 @@ function random_information_set_minimum_distance_bound(::HasGauges, ::IsNotCSS,
859867
# so must redo here
860868
_, mat = _rref_no_col_swap!(operators_to_reduce)
861869
anti = mat * check_against
862-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i:i, :]) for i in axes(anti, 1)), :]))
870+
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
863871
found = operators_to_reduce[index, :]
864872
else
865873
verbose && println("Bounding the full bare distance")
@@ -874,11 +882,11 @@ function random_information_set_minimum_distance_bound(::HasGauges, ::IsNotCSS,
874882
# so must redo here
875883
_, mat = _rref_no_col_swap!(operators_to_reduce)
876884
anti = mat * check_against
877-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i:i, :]) for i in axes(anti, 1)), :]))
885+
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
878886
found = operators_to_reduce[index, :]
879887
end
880888

881-
uppers, founds = _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound,
889+
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
882890
curr_u_bound, found, max_iterst, n, verbose)
883891
loc = argmin(uppers)
884892
if dressed
@@ -889,8 +897,8 @@ function random_information_set_minimum_distance_bound(::HasGauges, ::IsNotCSS,
889897
return uppers[loc], founds[loc]
890898
end
891899

892-
function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsNotCSS, ::HasLogicals,
893-
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int)
900+
function random_information_set_minimum_distance_bound!(::HasNoGauges, ::IsNotCSS, ::HasLogicals,
901+
S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int, verbose::Bool)
894902
# stabilizer code
895903

896904
which == :full || throw(ArguementError(which, "Parameter is not valid for non-CSS codes."))
@@ -902,29 +910,31 @@ function random_information_set_minimum_distance_bound(::HasNoGauges, ::IsNotCSS
902910
operators_to_reduce = vcat(stabs, logs)
903911
check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
904912
curr_l_bound = S.l_bound
913+
verbose && println("Starting lower bound: $curr_l_bound")
905914

906915
# this is done in the constructor but the logical is not stored at the time
907916
# so must redo here
908917
_, mat = _rref_no_col_swap!(operators_to_reduce)
909918
anti = mat * check_against
910919
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i:i, :]) for i in axes(anti, 1)), :]))
911920
found = operators_to_reduce[index, :]
921+
verbose && println("Starting upper bound: $curr_u_bound")
912922

913-
uppers, founds = _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound,
914-
curr_u_bound, found, max_iterst, n, verbose)
923+
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
924+
curr_u_bound, found, max_iters, n, verbose)
915925
loc = argmin(uppers)
916926
S.u_bound = uppers[loc]
917927
return uppers[loc], founds[loc]
918928
end
919929

920930
# TODO rewrite all for graph states
921931
# function random_information_set_minimum_distance_bound(::Union{HasGauges, HasNoGauges},
922-
# ::Union{IsCSS, IsNotCSS}, ::HasNoLogicals, S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int)
932+
# ::Union{IsCSS, IsNotCSS}, ::HasNoLogicals, S::AbstractSubsystemCode, which::Symbol, dressed::Bool, max_iters::Int, verbose::Bool)
923933
# # graph state
924934

925935
# end
926936

927-
function _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound::Int, curr_u_bound::Int, found, max_iters::Int, n::Int, verbose::Bool)
937+
function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int, curr_u_bound::Int, found, max_iters::Int, n::Int, verbose::Bool)
928938
num_thrds = Threads.nthreads()
929939
verbose && println("Detected $num_thrds threads.")
930940

@@ -933,7 +943,7 @@ function _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound::Int,
933943
founds = [found for _ in 1:num_thrds]
934944
thread_load = Int(floor(max_iters / num_thrds))
935945
Threads.@threads for t in 1:num_thrds
936-
log_test = zeros(Int, size(check_against, 2))
946+
log_test = zeros(Int, size(operators_to_reduce, 1), size(check_against, 2))
937947
for _ in 1:thread_load
938948
if flag[]
939949
perm = shuffle(1:n)
@@ -946,12 +956,13 @@ function _RIS_bound_loop(operators_to_reduce, check_against, curr_l_bound::Int,
946956
if any(!iszero, log_test[i, :])
947957
w = 0
948958
@inbounds for j in 1:n
949-
iszero(perm_ops[i. j]) && iszero(perm_ops[i, j + n]) || (w += 1;)
959+
iszero(perm_ops[i, j] % 2) && iszero(perm_ops[i, j + n] % 2) || (w += 1;)
950960
end
951961

952962
if uppers[t] > w
953963
uppers[t] = w
954-
founds[t] .= perm_ops[i, invperm!(perm2)]
964+
# maybe use invpermute! here?
965+
founds[t] .= perm_ops[i, invperm(perm2)]
955966
verbose && println("Adjusting upper bound: $w")
956967
if curr_l_bound == w
957968
verbose && println("Found a logical that matched the lower bound of $curr_l_bound")

src/utils.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ end
127127
weight(v::T) where T <: Union{CTMatrixTypes, Vector{<:CTFieldElem}, Vector{S}, Adjoint{S, Vector{S}}, AbstractMatrix{S}} where S <: Integer = Hamming_weight(v)
128128
wt(v::T) where T <: Union{CTMatrixTypes, Vector{<:CTFieldElem}, Vector{S}, Adjoint{S, Vector{S}}, AbstractMatrix{S}} where S <: Integer = Hamming_weight(v)
129129

130-
# TODO polish and export
131-
function row_wts_symplectic(A::CTMatrixTypes)
130+
# TODO polish and export?
131+
function row_wts_symplectic(A::Union{CTMatrixTypes, Matrix{<: Integer}, Matrix{Bool}})
132132
nc = size(A, 2)
133133
iseven(nc) || throw(ArgumentError("Input does not have even length"))
134134
n = div(nc, 2)
@@ -311,11 +311,12 @@ end
311311
# end
312312
_Flint_matrix_to_Julia_int_vector(A) = vec(_Flint_matrix_to_Julia_int_matrix(A))
313313

314-
function _non_pivot_cols(A::CTMatrixTypes, type::Symbol=:nsp)
315-
type [:sp, :nsp]
314+
function _non_pivot_cols(A::CTMatrixTypes, type::Symbol = :nsp)
315+
type (:sp, :nsp) || throw(DomainError(type, "Parameter should be `:sp` (sparse) or `:nsp` (not sparse)."))
316+
316317
if type == :sp
317318
return setdiff(collect(1:ncols(A)), [x.pos[1] for x in A])
318-
else #if type == :nsp - not sparse
319+
else
319320
nonpivots = Vector{Int}()
320321
i = 1
321322
j = 1
@@ -391,7 +392,9 @@ end
391392
# return maxlen
392393
# end
393394

394-
function _remove_empty(A::Union{CTMatrixTypes, Matrix{<:Number}, BitMatrix}, type::Symbol)
395+
function _remove_empty(A::Union{CTMatrixTypes, Matrix{<: Number}, BitMatrix, Matrix{Bool}},
396+
type::Symbol)
397+
395398
type (:rows, :cols) || throw(ArgumentError("Unknown type in _remove_empty"))
396399

397400
del = Vector{Int}()
@@ -430,12 +433,12 @@ end
430433
_rref_no_col_swap(M::CTMatrixTypes, row_range::Base.OneTo{Int}, col_range::Base.OneTo{Int}) = _rref_no_col_swap(M, 1:row_range.stop, 1:col_range.stop)
431434
_rref_no_col_swap(M::CTMatrixTypes) = _rref_no_col_swap(M, axes(M, 1), axes(M, 2))
432435

433-
function _rref_no_col_swap(A::Union{BitMatrix, Matrix{Bool}}, row_range::UnitRange{Int} = 1:size(A, 1),
434-
col_range::UnitRange{Int} = 1:size(A, 2))
436+
function _rref_no_col_swap(A::Union{BitMatrix, Matrix{Bool}, Matrix{<: Integer}},
437+
row_range::UnitRange{Int} = 1:size(A, 1), col_range::UnitRange{Int} = 1:size(A, 2))
435438

436439
B = copy(A)
437440
_rref_no_col_swap!(B, row_range, col_range)
438-
B
441+
return B
439442
end
440443

441444
function _rref_no_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
@@ -519,8 +522,8 @@ function _rref_no_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_ran
519522
return nothing
520523
end
521524

522-
function _rref_no_col_swap!(A::Union{BitMatrix, Matrix{Bool}}, row_range::UnitRange{Int} = 1:size(A, 1),
523-
col_range::UnitRange{Int} = 1:size(A, 2))
525+
function _rref_no_col_swap!(A::Union{BitMatrix, Matrix{Bool}, Matrix{<: Integer}},
526+
row_range::UnitRange{Int} = 1:size(A, 1), col_range::UnitRange{Int} = 1:size(A, 2))
524527

525528
isempty(row_range) && return nothing
526529
isempty(col_range) && return nothing
@@ -532,7 +535,7 @@ function _rref_no_col_swap!(A::Union{BitMatrix, Matrix{Bool}}, row_range::UnitRa
532535
# find first pivot
533536
ind = 0
534537
for k in i:nr
535-
if A[k, j]
538+
if !iszero(A[k, j])
536539
ind = k
537540
break
538541
end
@@ -547,10 +550,11 @@ function _rref_no_col_swap!(A::Union{BitMatrix, Matrix{Bool}}, row_range::UnitRa
547550
# eliminate
548551
for k in row_range
549552
if k != i
550-
if A[k, j]
553+
if !iszero(A[k, j])
551554
# do a manual loop here to reduce allocations
552555
@simd for l in axes(A, 2)
553-
A[k, l] ⊻= A[i, l]
556+
# A[k, l] ⊻= A[i, l]
557+
A[k, l] = (A[k, l] + A[i, l]) % 2
554558
end
555559
end
556560
end

0 commit comments

Comments
 (0)