Skip to content

Commit fe7e2cc

Browse files
committed
significant QDistRnd css stabilizer code improvements
1 parent 9b447a1 commit fe7e2cc

File tree

2 files changed

+128
-55
lines changed

2 files changed

+128
-55
lines changed

src/Quantum/weight_dist.jl

Lines changed: 105 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ function random_information_set_minimum_distance_bound!(::HasGauges, ::IsCSS, ::
774774
end
775775
end
776776

777-
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
777+
uppers, founds = _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound,
778778
curr_u_bound, found, max_iters, n, verbose)
779779
loc = argmin(uppers)
780780
if dressed
@@ -804,65 +804,61 @@ function random_information_set_minimum_distance_bound!(::HasNoGauges, ::IsCSS,
804804

805805
n = S.n
806806
if which == :full
807-
verbose && println("Bounding the full distance")
808-
# println(_rref_no_col_swap_binary!(_Flint_matrix_to_Julia_T_matrix(stabilizers(S), Bool)))
809-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
807+
# verbose && println("Bounding the full distance")
808+
# stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
809+
# _rref_no_col_swap_binary!(stabs)
810+
# stabs = _remove_empty(stabs, :rows)
811+
# logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Int)
812+
# operators_to_reduce = vcat(stabs, logs)
813+
# check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
814+
# curr_l_bound = S.l_bound
815+
# verbose && println("Starting lower bound: $curr_l_bound")
816+
817+
# # this is done in the constructor but the logical is not stored at the time
818+
# # so must redo here
819+
# mat = _rref_no_col_swap_binary(operators_to_reduce)
820+
# anti = mat * check_against .% order(field(S))
821+
# log_locations = findall(!iszero(anti[i, :]) for i in axes(anti, 1))
822+
# curr_u_bound, index = findmin(row_wts_symplectic(mat[log_locations, :]))
823+
# found = mat[log_locations[index], :]
824+
# verbose && println("Starting upper bound: $curr_u_bound")
825+
upperx, foundx = random_information_set_minimum_distance_bound!(HasNoGauges(), IsCSS(), HasLogicals(), S, :X, dressed, max_iters, verbose)
826+
upperz, foundz = random_information_set_minimum_distance_bound!(HasNoGauges(), IsCSS(), HasLogicals(), S, :Z, dressed, max_iters, verbose)
827+
if upperx <= upperz
828+
return upperx, foundx
829+
else
830+
return upperz, foundz
831+
end
832+
else
833+
verbose && println("Bounding the X-distance")
834+
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S)[:, (which == :X ? (1:n) : (n + 1:2n))], Int)
810835
_rref_no_col_swap_binary!(stabs)
811836
stabs = _remove_empty(stabs, :rows)
812-
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Int)
837+
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S)[:, (which == :X ? (1:n) : (n + 1:2n))], Int)
838+
logs = _remove_empty(logs, :rows)
813839
operators_to_reduce = vcat(stabs, logs)
814-
check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
815-
curr_l_bound = S.l_bound
840+
check_against = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S)[:, (which == :X ? (n + 1:2n) : (1:n))], Int)
841+
check_against = permutedims(_remove_empty(check_against, :rows))
842+
curr_l_bound = S.l_bound_dx
816843
verbose && println("Starting lower bound: $curr_l_bound")
817844

818845
# this is done in the constructor but the logical is not stored at the time
819846
# so must redo here
820847
mat = _rref_no_col_swap_binary(operators_to_reduce)
821-
anti = mat * check_against .% order(field(S))
848+
anti = mat * check_against .% 2
822849
log_locations = findall(!iszero(anti[i, :]) for i in axes(anti, 1))
823-
curr_u_bound, index = findmin(row_wts_symplectic(mat[log_locations, :]))
850+
curr_u_bound, index = findmin(count(!iszero, mat[log_locations[i], :]) for i in eachindex(log_locations))
824851
found = mat[log_locations[index], :]
825852
verbose && println("Starting upper bound: $curr_u_bound")
826-
elseif which == :X
827-
verbose && println("Bounding the X-distance")
828-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
829-
_rref_no_col_swap_binary!(stabs)
830-
stabs = _remove_empty(stabs, :rows)
831-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]), Int)
832-
operators_to_reduce = vcat(stabs, logs)
833-
check_against = permutedims(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals])[:, [n + 1:2n; 1:n]])
834-
curr_l_bound = S.l_bound_dx
835-
verbose && println("Starting lower bound: $curr_l_bound")
853+
end
836854

837-
# this is done in the constructor but the logical is not stored at the time
838-
# so must redo here
839-
mat = _rref_no_col_swap_binary(operators_to_reduce)
840-
anti = mat * check_against
841-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
842-
found = operators_to_reduce[index, :]
843-
verbose && println("Starting upper bound: $curr_u_bound")
855+
uppers, founds = if which == :full
856+
_RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound,
857+
curr_u_bound, found, max_iters, n, verbose)
844858
else
845-
verbose && println("Bounding the Z-distance")
846-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
847-
_rref_no_col_swap_binary!(stabs)
848-
stabs = _remove_empty(stabs, :rows)
849-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals]), Int)
850-
operators_to_reduce = vcat(stabs, logs)
851-
check_against = permutedims(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]))
852-
curr_l_bound = S.l_bound_dx
853-
verbose && println("Starting lower bound: $curr_l_bound")
854-
855-
# this is done in the constructor but the logical is not stored at the time
856-
# so must redo here
857-
mat = _rref_no_col_swap_binary(operators_to_reduce)
858-
anti = mat * check_against
859-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
860-
found = operators_to_reduce[index, :]
861-
verbose && println("Starting upper bound: $curr_u_bound")
859+
_RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
860+
curr_u_bound, found, max_iters, n, verbose)
862861
end
863-
864-
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
865-
curr_u_bound, found, max_iters, n, verbose)
866862
loc = argmin(uppers)
867863
if which == :full
868864
S.u_bound = uppers[loc]
@@ -921,7 +917,7 @@ function random_information_set_minimum_distance_bound!(::HasGauges, ::IsNotCSS,
921917
verbose && println("Starting upper bound: $curr_u_bound")
922918
end
923919

924-
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
920+
uppers, founds = _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound,
925921
curr_u_bound, found, max_iters, n, verbose)
926922
loc = argmin(uppers)
927923
if dressed
@@ -958,7 +954,7 @@ function random_information_set_minimum_distance_bound!(::HasNoGauges, ::IsNotCS
958954
found = operators_to_reduce[index, :]
959955
verbose && println("Starting upper bound: $curr_u_bound")
960956

961-
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
957+
uppers, founds = _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound,
962958
curr_u_bound, found, max_iters, n, verbose)
963959
loc = argmin(uppers)
964960
S.u_bound = uppers[loc]
@@ -973,7 +969,7 @@ end
973969

974970
# end
975971

976-
function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int, curr_u_bound::Int, found, max_iters::Int, n::Int, verbose::Bool)
972+
function _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound::Int, curr_u_bound::Int, found, max_iters::Int, n::Int, verbose::Bool)
977973
num_thrds = Threads.nthreads()
978974
verbose && println("Detected $num_thrds threads.")
979975

@@ -982,16 +978,20 @@ function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int,
982978
founds = [found for _ in 1:num_thrds]
983979
thread_load = Int(floor(max_iters / num_thrds))
984980
remaining = max_iters - thread_load * num_thrds
985-
Threads.@threads for t in 1:num_thrds
986-
log_test = zeros(Int, size(operators_to_reduce, 1), size(check_against, 2))
981+
# Threads.@threads for t in 1:num_thrds
982+
for t in 1:num_thrds
983+
orig_ops = deepcopy(operators_to_reduce)
984+
log_test = zeros(Int, size(orig_ops, 1), size(check_against, 2))
985+
perm_ops = similar(orig_ops)
986+
ops = similar(orig_ops)
987+
perm = collect(1:n)
987988
for _ in 1:(thread_load + (t <= remaining ? 1 : 0))
988989
if flag[]
989-
perm = shuffle(1:n)
990-
perm2 = [perm; perm .+ n]
991-
perm_ops = operators_to_reduce[:, perm2]
990+
shuffle!(perm)
991+
_col_permutation_symp!(perm_ops, orig_ops, perm)
992992
# modifying this in place is not thread safe (apparently)
993993
perm_ops = _rref_no_col_swap_binary(perm_ops)
994-
ops = perm_ops[:, invperm(perm2)]
994+
_col_permutation_symp!(ops, perm_ops, invperm(perm))
995995
LinearAlgebra.mul!(log_test, ops, check_against)
996996
for i in axes(log_test, 1)
997997
# then ops[i, :] is a logical
@@ -1019,3 +1019,53 @@ function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int,
10191019

10201020
return uppers, founds
10211021
end
1022+
1023+
function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int, curr_u_bound::Int, found, max_iters::Int, n::Int, verbose::Bool)
1024+
num_thrds = Threads.nthreads()
1025+
verbose && println("Detected $num_thrds threads.")
1026+
1027+
flag = Threads.Atomic{Bool}(true)
1028+
uppers = [curr_u_bound for _ in 1:num_thrds]
1029+
founds = [found for _ in 1:num_thrds]
1030+
thread_load = Int(floor(max_iters / num_thrds))
1031+
remaining = max_iters - thread_load * num_thrds
1032+
Threads.@threads for t in 1:num_thrds
1033+
orig_ops = deepcopy(operators_to_reduce)
1034+
log_test = zeros(Int, size(orig_ops, 1), size(check_against, 2))
1035+
perm_ops = similar(orig_ops)
1036+
ops = similar(orig_ops)
1037+
perm = collect(1:n)
1038+
for _ in 1:(thread_load + (t <= remaining ? 1 : 0))
1039+
if flag[]
1040+
shuffle!(perm)
1041+
_col_permutation!(perm_ops, orig_ops, perm)
1042+
# modifying this in place is not thread safe (apparently)
1043+
_rref_no_col_swap_binary!(perm_ops)
1044+
_col_permutation!(ops, perm_ops, invperm(perm))
1045+
LinearAlgebra.mul!(log_test, ops, check_against)
1046+
for i in axes(log_test, 1)
1047+
# then ops[i, :] is a logical
1048+
if any(isodd, log_test[i, :])
1049+
w = 0
1050+
@inbounds for j in 1:n
1051+
isodd(ops[i, j]) && (w += 1;)
1052+
end
1053+
1054+
if uppers[t] > w
1055+
uppers[t] = w
1056+
founds[t] .= ops[i, :]
1057+
verbose && println("Adjusting (thread's local) upper bound: $w")
1058+
if curr_l_bound == w
1059+
verbose && println("Found a logical that matched the lower bound of $curr_l_bound")
1060+
Threads.atomic_cas!(flag, true, false)
1061+
break
1062+
end
1063+
end
1064+
end
1065+
end
1066+
end
1067+
end
1068+
end
1069+
1070+
return uppers, founds
1071+
end

src/utils.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,29 @@ end
823823
_rref_symp_col_swap!(A::CTMatrixTypes) = _rref_symp_col_swap!(A, axes(A, 1), axes(A, 2))
824824
_rref_symp_col_swap(A::CTMatrixTypes) = (B = deepcopy(A); return _rref_symp_col_swap!(B);)
825825

826+
function _col_permutation!(X::Matrix{T}, A::Matrix{T}, p::AbstractVector{Int}) where T
827+
length(p) == size(A, 2) || error()
828+
size(X) == size(A) || error()
829+
for i in axes(X, 1)
830+
for j in axes(X, 2)
831+
X[i, j] = A[i, p[j]]
832+
end
833+
end
834+
return nothing
835+
end
836+
837+
function _col_permutation_symp!(X::Matrix{T}, A::Matrix{T}, p::AbstractVector{Int}) where T
838+
n = length(p)
839+
2n == size(A, 2) || error()
840+
size(X) == size(A) || error()
841+
for i in axes(X, 1)
842+
for j in axes(X, 2)
843+
X[i, j] = A[i, p[mod1(j, n)]]
844+
end
845+
end
846+
return nothing
847+
end
848+
826849
function digits_to_int(x::Vector{Int}, base::Int=2)
827850
res = 0
828851
for digit in x

0 commit comments

Comments
 (0)