Skip to content

Commit b7f820f

Browse files
committed
improved qdistrnd: feature parity with GAP for GF(2)
1 parent ebef4df commit b7f820f

File tree

2 files changed

+72
-34
lines changed

2 files changed

+72
-34
lines changed

src/Quantum/weight_dist.jl

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,20 @@ function qdistrnd(S::T, type::Symbol = :both; dressed::Bool = true, iters::Int =
631631
elseif GaugeTrait(T) == HasGauges() && CSSTrait(T) == IsCSS() && type == :both && dressed
632632
min(_qdistrnd(IsCSS(), S, :X, dressed, iters, d_lower_bound),
633633
_qdistrnd(IsCSS(), S, :Z, dressed, iters, d_lower_bound))
634-
elseif CSSTrait(T) == IsNotCSS() || type == :both
634+
elseif CSSTrait(T) == IsNotCSS()
635635
_qdistrnd(IsNotCSS(), S, type, dressed, iters, d_lower_bound)
636+
elseif type == :both
637+
dx, logx = _qdistrnd(IsCSS(), S, :X, dressed, iters, d_lower_bound)
638+
if dx <= d_lower_bound
639+
dx, logx
640+
else
641+
dz, logz = _qdistrnd(IsCSS(), S, :Z, dressed, iters, d_lower_bound)
642+
if dx <= dz
643+
dx, logx
644+
else
645+
dz, logz
646+
end
647+
end
636648
else # IsCSS() and type in (:X, :Z)
637649
_qdistrnd(IsCSS(), S, type, dressed, iters, d_lower_bound)
638650
end
@@ -644,28 +656,36 @@ function _qdistrnd(::IsNotCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
644656

645657
n = length(S)
646658
k = dimension(S)
647-
stabs = _Flint_matrix_to_Julia_bool_matrix(stabilizers(S))
659+
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Bool)
648660
_rref_no_col_swap!(stabs)
649661
stabs = _remove_empty(stabs, :rows)
650662
nstabs = size(stabs, 1)
651-
logs = _Flint_matrix_to_Julia_bool_matrix(logicals_matrix(S))
652-
bound = fill(2n, Threads.nthreads())
663+
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Bool)
664+
operators = vcat(stabs, logs)
665+
anticommuting_logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Bool)
666+
anticommuting_logs_symp_T = permutedims(anticommuting_logs[:, [n + 1: 2n; 1:n]])
667+
668+
temp = logs[:, 1:n] + logs[:, 1 + n:2n]
669+
bestlogwt, bestlogindex = findmin(count(!iszero, temp[i, :]) for i in 1:k)
670+
bound = fill(bestlogwt, Threads.nthreads())
671+
saved_logs = fill(logs[bestlogindex, :], Threads.nthreads())
653672
Threads.@threads for iter in 1:iters
654673
tid = Threads.threadid()
655674
perm = shuffle(1:n)
656675
perm2 = [perm; perm .+ n]
657-
s = stabs[:, perm2]
658-
l = logs[:, perm2]
659-
_rref_no_col_swap!(s)
660-
pivots = [findfirst(view(s, i, :)) for i in 1:nstabs]
661-
for (j, pivot) in enumerate(pivots)
662-
for i in 1:2k
663-
if l[i, pivot]
664-
l[i, :] .⊻= s[j, :]
676+
permops = operators[:, perm2]
677+
_rref_no_col_swap!(permops)
678+
logtest = permops * anticommuting_logs_symp_T[perm2, :] .% 0x02
679+
for i in axes(logtest, 1)
680+
if !all(iszero, logtest[i, :]) # then permops[i, :] is a logical
681+
temp = permops[i, 1:n] + permops[i, 1 + n:2n]
682+
wt = count(!iszero, temp)
683+
if wt < bound[tid]
684+
bound[tid] = wt
685+
saved_logs[tid] .= permops[i, invperm(perm2)]
665686
end
666687
end
667688
end
668-
bound[tid] = min(bound[tid], minimum(count(view(l, i, :)) for i in 1:2k))
669689
if bound[tid] <= d_lower_bound
670690
if bound[tid] < d_lower_bound
671691
@warn "The given `d_lower_bound` is greater than the distance."
@@ -674,8 +694,11 @@ function _qdistrnd(::IsNotCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
674694
end
675695
end
676696
d_upper_bound = minimum(bound)
677-
678-
return d_upper_bound
697+
best_log = matrix(field(S), permutedims(saved_logs[argmin(bound)]))
698+
699+
is_logical(S, best_log) || error("Found something that isn't a logical. There is a bug in qdistrnd.")
700+
701+
return d_upper_bound, best_log
679702
end
680703

681704
function _qdistrnd(::IsCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
@@ -685,7 +708,7 @@ function _qdistrnd(::IsCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
685708
k = dimension(S)
686709

687710
# for dressed distance of a subsystem code, need to include gauges as stabilizers
688-
stabs = if GaugeTrait(T) == HasGauges() && dressed
711+
stabs = _Flint_matrix_to_Julia_T_matrix(if GaugeTrait(T) == HasGauges() && dressed
689712
if type == :X
690713
G = _remove_empty(gauges_matrix(S)[:, 1:n], :rows)
691714
vcat(X_stabilizers(S), G)
@@ -695,34 +718,37 @@ function _qdistrnd(::IsCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
695718
end
696719
else
697720
type == :X ? X_stabilizers(S) : Z_stabilizers(S)
698-
end |> _Flint_matrix_to_Julia_bool_matrix
721+
end, Bool)
699722
_rref_no_col_swap!(stabs)
700723
stabs = _remove_empty(stabs, :rows)
701724
nstabs = size(stabs, 1)
702725

703-
logs = if type == :X
704-
_remove_empty(logicals_matrix(S)[:, 1:n], :rows)
705-
else
706-
_remove_empty(logicals_matrix(S)[:, n + 1:2n], :rows)
707-
end |> _Flint_matrix_to_Julia_bool_matrix
726+
logs = _Flint_matrix_to_Julia_T_matrix(_remove_empty(
727+
logicals_matrix(S)[:, (1:n) .+ (type == :X ? 0 : n)], :rows), Bool)
728+
operators = vcat(stabs, logs)
729+
anticommuting_logs = _Flint_matrix_to_Julia_T_matrix(_remove_empty(
730+
logicals_matrix(S)[:, (1:n) .+ (type == :Z ? 0 : n)], :rows), Bool)
731+
anticommuting_logs_T = permutedims(anticommuting_logs)
708732

709733
# Basic algorithm for CSS codes over GF(2):
710-
bound = fill(n, Threads.nthreads())
734+
bestlogwt, bestlogindex = findmin(count(logs[i, :]) for i in 1:k)
735+
bound = fill(bestlogwt, Threads.nthreads())
736+
saved_logs = fill(logs[bestlogindex, :], Threads.nthreads())
711737
Threads.@threads for iter in 1:iters
712738
tid = Threads.threadid()
713739
perm = shuffle(1:n)
714-
s = stabs[:, perm]
715-
l = logs[:, perm]
716-
_rref_no_col_swap!(s)
717-
pivots = [findfirst(view(s, i, :)) for i in 1:nstabs]
718-
for (j, pivot) in enumerate(pivots)
719-
for i in 1:k
720-
if l[i, pivot]
721-
l[i, :] .⊻= s[j, :]
740+
permops = operators[:, perm]
741+
_rref_no_col_swap!(permops)
742+
logtest = permops * anticommuting_logs_T[perm, :] .% 0x02
743+
for i in axes(logtest, 1)
744+
if !all(iszero, logtest[i, :]) # then permops[i, :] is a logical
745+
wt = count(!iszero, permops[i, :])
746+
if wt < bound[tid]
747+
bound[tid] = wt
748+
saved_logs[tid] .= permops[i, invperm(perm)]
722749
end
723750
end
724751
end
725-
bound[tid] = min(bound[tid], minimum(count(view(l, i, :)) for i in 1:k))
726752
if bound[tid] <= d_lower_bound
727753
if bound[tid] < d_lower_bound
728754
@warn "The given `d_lower_bound` is greater than the distance."
@@ -731,6 +757,14 @@ function _qdistrnd(::IsCSS, S::T, type::Symbol, dressed::Bool, iters::Int,
731757
end
732758
end
733759
d_upper_bound = minimum(bound)
734-
735-
return d_upper_bound
760+
best_log = matrix(field(S), permutedims(saved_logs[argmin(bound)]))
761+
if type == :X
762+
best_log = hcat(best_log, zero_matrix(field(S), 1, n))
763+
else
764+
best_log = hcat(zero_matrix(field(S), 1, n), best_log)
765+
end
766+
767+
is_logical(S, best_log) || error("Found something that isn't a logical. There is a bug in qdistrnd.")
768+
769+
return d_upper_bound, best_log
736770
end

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ function _Flint_matrix_to_Julia_bool_matrix(A::CTMatrixTypes)
294294
Matrix{Bool}(_Flint_matrix_to_Julia_int_matrix(A))
295295
end
296296

297+
function _Flint_matrix_to_Julia_T_matrix(A::CTMatrixTypes, ::Type{T}) where T <: Number
298+
Matrix{T}(_Flint_matrix_to_Julia_int_matrix(A))
299+
end
300+
297301
# function _Flint_matrix_to_Julia_int_vector(A)
298302
# # (nr == 1 || nc == 1) || throw(ArgumentError("Cannot cast matrix to vector"))
299303
# return _Flint_matrix_element_to_Julia_int(A, 1, 1)

0 commit comments

Comments
 (0)