Skip to content

Commit 09d5e7a

Browse files
committed
QDistRnd working properly for CSS subsystem, some util improvements
1 parent 85caafe commit 09d5e7a

File tree

2 files changed

+86
-148
lines changed

2 files changed

+86
-148
lines changed

src/Quantum/weight_dist.jl

Lines changed: 46 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -657,124 +657,49 @@ function random_information_set_minimum_distance_bound!(::HasGauges, ::IsCSS, ::
657657
# this is a CSS subsystem code
658658

659659
n = S.n
660-
if dressed
661-
if which == :full
662-
verbose && println("Bounding the full dressed distance")
663-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
664-
_rref_no_col_swap_binary!(stabs)
665-
stabs = _remove_empty(stabs, :rows)
666-
gauges = _Flint_matrix_to_Julia_T_matrix(gauges_matrix(S), Int)
667-
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Int)
668-
operators_to_reduce = vcat(stabs, gauges, logs)
669-
check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
670-
curr_l_bound = S.l_bound_dressed
671-
verbose && println("Starting lower bound: $curr_l_bound")
672-
673-
# this is done in the constructor but the logical is not stored at the time
674-
# so must redo here
675-
mat = _rref_no_col_swap_binary(operators_to_reduce)
676-
anti = mat * check_against
677-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
678-
found = operators_to_reduce[index, :]
679-
verbose && println("Starting upper bound: $curr_u_bound")
680-
elseif which == :X
681-
verbose && println("Bounding the dressed X-distance")
682-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
683-
_rref_no_col_swap_binary!(stabs)
684-
stabs = _remove_empty(stabs, :rows)
685-
gauges = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[1][:, 1:n] for log in S.gauge_ops]), Int)
686-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]), Int)
687-
operators_to_reduce = vcat(stabs, gauges, logs)
688-
check_against = permutedims(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals])[:, [n + 1:2n; 1:n]])
689-
curr_l_bound = S.l_bound_dx_dressed
690-
verbose && println("Starting lower bound: $curr_l_bound")
691-
692-
# this is done in the constructor but the logical is not stored at the time
693-
# so must redo here
694-
mat = _rref_no_col_swap_binary(operators_to_reduce)
695-
anti = mat * check_against
696-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
697-
found = operators_to_reduce[index, :]
698-
verbose && println("Starting upper bound: $curr_u_bound")
660+
if which == :full
661+
upperx, foundx = random_information_set_minimum_distance_bound!(HasGauges(), IsCSS(), HasLogicals(), S, :X, dressed, max_iters, verbose)
662+
upperz, foundz = random_information_set_minimum_distance_bound!(HasGauges(), IsCSS(), HasLogicals(), S, :Z, dressed, max_iters, verbose)
663+
if upperx <= upperz
664+
if dressed
665+
S.u_bound_dressed = min(upperx, S.u_bound_dressed)
666+
else
667+
S.u_bound_bare = min(upperx, S.u_bound_bare)
668+
end
669+
return upperx, foundx
699670
else
700-
verbose && println("Bounding the dressed Z-distance")
701-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
702-
_rref_no_col_swap_binary!(stabs)
703-
stabs = _remove_empty(stabs, :rows)
704-
gauges = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[2][:, n + 1:end] for log in S.gauge_ops]), Int)
705-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals]), Int)
706-
operators_to_reduce = vcat(stabs, gauges, logs)
707-
check_against = permutedims(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]))
708-
curr_l_bound = S.l_bound_dz_dressed
709-
verbose && println("Starting lower bound: $curr_l_bound")
710-
711-
# this is done in the constructor but the logical is not stored at the time
712-
# so must redo here
713-
mat = _rref_no_col_swap_binary(operators_to_reduce)
714-
anti = mat * check_against
715-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
716-
found = operators_to_reduce[index, :]
717-
verbose && println("Starting upper bound: $curr_u_bound")
671+
if dressed
672+
S.u_bound_dressed = min(upperz, S.u_bound_dressed)
673+
else
674+
S.u_bound_bare = min(upperz, S.u_bound_bare)
675+
end
676+
return upperz, foundz
718677
end
719678
else
720-
if which == :full
721-
verbose && println("Bounding the full bare distance")
722-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
723-
_rref_no_col_swap_binary!(stabs)
724-
stabs = _remove_empty(stabs, :rows)
725-
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S), Int)
726-
operators_to_reduce = vcat(stabs, logs)
727-
check_against = permutedims(logs[:, [n + 1:2n; 1:n]])
728-
curr_l_bound = S.l_bound_bare
729-
verbose && println("Starting lower bound: $curr_l_bound")
730-
731-
# this is done in the constructor but the logical is not stored at the time
732-
# so must redo here
733-
mat = _rref_no_col_swap_binary(operators_to_reduce)
734-
anti = mat * check_against
735-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
736-
found = operators_to_reduce[index, :]
737-
verbose && println("Starting upper bound: $curr_u_bound")
738-
elseif which == :X
739-
verbose && println("Bounding the bare X-distance")
740-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
741-
_rref_no_col_swap_binary!(stabs)
742-
stabs = _remove_empty(stabs, :rows)
743-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]), Int)
744-
operators_to_reduce = vcat(stabs, logs)
745-
check_against = permutedims(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals])[:, [n + 1:2n; 1:n]])
746-
curr_l_bound = S.l_bound_dx_bare
747-
verbose && println("Starting lower bound: $curr_l_bound")
748-
749-
# this is done in the constructor but the logical is not stored at the time
750-
# so must redo here
751-
mat = _rref_no_col_swap_binary(operators_to_reduce)
752-
anti = mat * check_against
753-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
754-
found = operators_to_reduce[index, :]
755-
verbose && println("Starting upper bound: $curr_u_bound")
679+
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S)[:, (which == :X ? (1:n) : (n + 1:2n))], UInt8)
680+
if dressed
681+
gauges = _Flint_matrix_to_Julia_T_matrix(gauge_operators_matrix(S)[:, (which == :X ? (1:n) : (n + 1:2n))], UInt8)
682+
stabs = [stabs; gauges]
683+
end
684+
_rref_no_col_swap_binary!(stabs)
685+
stabs = _remove_empty(stabs, :rows)
686+
logs = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S)[:, (which == :X ? (1:n) : (n + 1:2n))], UInt8)
687+
logs = _remove_empty(logs, :rows)
688+
operators_to_reduce = vcat(stabs, logs)
689+
check_against = _Flint_matrix_to_Julia_T_matrix(logicals_matrix(S)[:, (which == :X ? (n + 1:2n) : (1:n))], UInt8)
690+
check_against = permutedims(_remove_empty(check_against, :rows))
691+
curr_l_bound = if dressed
692+
which == :X ? S.l_bound_dx_dressed : S.l_bound_dz_dressed
756693
else
757-
verbose && println("Bounding the bare Z-distance")
758-
stabs = _Flint_matrix_to_Julia_T_matrix(stabilizers(S), Int)
759-
_rref_no_col_swap_binary!(stabs)
760-
stabs = _remove_empty(stabs, :rows)
761-
logs = _Flint_matrix_to_Julia_T_matrix(reduce(vcat, [log[2][:, n + 1:end] for log in S.logicals]), Int)
762-
operators_to_reduce = vcat(stabs, logs)
763-
check_against = permutedims(reduce(vcat, [log[1][:, 1:n] for log in S.logicals]))
764-
curr_l_bound = S.l_bound_dx_bare
765-
verbose && println("Starting lower bound: $curr_l_bound")
766-
767-
# this is done in the constructor but the logical is not stored at the time
768-
# so must redo here
769-
mat = _rref_no_col_swap_binary(operators_to_reduce)
770-
anti = mat * check_against
771-
curr_u_bound, index = findmin(row_wts_symplectic(mat[findall(!iszero(anti[i, :]) for i in axes(anti, 1)), :]))
772-
found = operators_to_reduce[index, :]
773-
verbose && println("Starting upper bound: $curr_u_bound")
694+
which == :X ? S.l_bound_dx_bare : S.l_bound_dz_bare
774695
end
696+
verbose && println("Starting lower bound: $curr_l_bound")
697+
curr_u_bound, index = findmin(count(!iszero, logs[i, :]) for i in 1:size(logs, 1))
698+
found = logs[index, :]
699+
verbose && println("Starting upper bound: $curr_u_bound")
775700
end
776701

777-
uppers, founds = _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound,
702+
uppers, founds = _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound,
778703
curr_u_bound, found, max_iters, n, verbose)
779704
loc = argmin(uppers)
780705
if dressed
@@ -795,7 +720,14 @@ function random_information_set_minimum_distance_bound!(::HasGauges, ::IsCSS, ::
795720
end
796721
end
797722
verbose && println("Ending $max_iters iterations with an upper bound of $(uppers[loc])")
798-
return uppers[loc], founds[loc]
723+
flint_mat_found = if which == :full
724+
matrix(field(S), permutedims(founds[loc]))
725+
elseif which == :X
726+
matrix(field(S), [permutedims(founds[loc]) zeros(Int, 1, n)])
727+
else
728+
matrix(field(S), [zeros(Int, 1, n) permutedims(founds[loc])])
729+
end
730+
return uppers[loc], flint_mat_found
799731
end
800732

801733
function random_information_set_minimum_distance_bound!(::HasNoGauges, ::IsCSS, ::HasLogicals,
@@ -979,7 +911,8 @@ function _RIS_bound_loop_symp!(operators_to_reduce, check_against, curr_l_bound:
979911
shuffle!(perm)
980912
_col_permutation_symp!(perm_ops, orig_ops, perm)
981913
# modifying this in place is not thread safe (apparently)
982-
perm_ops = _rref_no_col_swap_binary(perm_ops)
914+
# perm_ops = _rref_no_col_swap_binary(perm_ops)
915+
_rref_no_col_swap_binary!(perm_ops)
983916
_col_permutation_symp!(ops, perm_ops, invperm(perm))
984917
LinearAlgebra.mul!(log_test, ops, check_against)
985918
for i in axes(log_test, 1)

src/utils.jl

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -425,31 +425,31 @@ function _remove_empty(A::Union{CTMatrixTypes, Matrix{<: Number}, BitMatrix, Mat
425425
end
426426
end
427427

428-
function _rref_no_col_swap(M::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
428+
function _rref_no_col_swap(M::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int})
429429
A = deepcopy(M)
430430
_rref_no_col_swap!(A, row_range, col_range)
431431
return A
432432
end
433-
_rref_no_col_swap(M::CTMatrixTypes, row_range::Base.OneTo{Int}, col_range::Base.OneTo{Int}) = _rref_no_col_swap(M, 1:row_range.stop, 1:col_range.stop)
433+
_rref_no_col_swap(M::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int}) = _rref_no_col_swap(M, 1:row_range.stop, 1:col_range.stop)
434434
_rref_no_col_swap(M::CTMatrixTypes) = _rref_no_col_swap(M, axes(M, 1), axes(M, 2))
435435

436436
function _rref_no_col_swap_binary(A::Union{BitMatrix, Matrix{Bool}, Matrix{<: Integer}},
437-
row_range::UnitRange{Int} = 1:size(A, 1), col_range::UnitRange{Int} = 1:size(A, 2))
437+
row_range::AbstractUnitRange{Int} = 1:size(A, 1), col_range::AbstractUnitRange{Int} = 1:size(A, 2))
438438

439439
B = deepcopy(A)
440440
_rref_no_col_swap_binary!(B, row_range, col_range)
441441
return B
442442
end
443443

444-
function _rref_no_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
444+
function _rref_no_col_swap!(A::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int})
445445
# don't do anything to A if the range is empty
446446
isempty(row_range) && return nothing
447447
isempty(col_range) && return nothing
448448

449-
i = row_range.start
450-
j = col_range.start
451-
nr = row_range.stop
452-
nc = col_range.stop
449+
i = first(row_range)
450+
j = first(col_range)
451+
nr = last(row_range)
452+
nc = last(col_range)
453453
if Int(order(base_ring(A))) != 2
454454
while i <= nr && j <= nc
455455
# find first pivot
@@ -523,14 +523,14 @@ function _rref_no_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_ran
523523
end
524524

525525
function _rref_no_col_swap_binary!(A::Union{BitMatrix, Matrix{Bool}, Matrix{<: Integer}},
526-
row_range::UnitRange{Int} = 1:size(A, 1), col_range::UnitRange{Int} = 1:size(A, 2))
526+
row_range::AbstractUnitRange{Int} = 1:size(A, 1), col_range::AbstractUnitRange{Int} = 1:size(A, 2))
527527

528528
isempty(row_range) && return nothing
529529
isempty(col_range) && return nothing
530-
i = row_range.start
531-
j = col_range.start
532-
nr = row_range.stop
533-
nc = col_range.stop
530+
i = first(row_range)
531+
j = first(col_range)
532+
nr = last(row_range)
533+
nc = last(col_range)
534534
while i <= nr && j <= nc
535535
# find first pivot
536536
ind = 0
@@ -565,15 +565,15 @@ function _rref_no_col_swap_binary!(A::Union{BitMatrix, Matrix{Bool}, Matrix{<: I
565565
return nothing
566566
end
567567

568-
function _rref_col_swap(M::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
568+
function _rref_col_swap(M::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int})
569569
A = deepcopy(M)
570570
rnk, P = _rref_col_swap!(A, row_range, col_range)
571571
return rnk, A, P
572572
end
573-
_rref_col_swap(M::CTMatrixTypes, row_range::Base.OneTo{Int}, col_range::Base.OneTo{Int}) = _rref_col_swap(M, 1:row_range.stop, 1:col_range.stop)
573+
_rref_col_swap(M::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int}) = _rref_col_swap(M, 1:row_range.stop, 1:col_range.stop)
574574
_rref_col_swap(M::CTMatrixTypes) = _rref_col_swap(M, axes(M, 1), axes(M, 2))
575575

576-
function _rref_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
576+
function _rref_col_swap!(A::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int})
577577
# don't do anything to A if the range is empty, return rank 0 and missing permutation matrix
578578
isempty(row_range) && return 0, missing
579579
isempty(col_range) && return 0, missing
@@ -583,10 +583,10 @@ function _rref_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range:
583583
nc_A = ncols(A)
584584

585585
rnk = 0
586-
i = row_range.start
587-
j = col_range.start
588-
nr = row_range.stop
589-
nc = col_range.stop
586+
i = first(row_range)
587+
j = first(col_range)
588+
nr = last(row_range)
589+
nc = last(col_range)
590590
if Int(order(base_ring(A))) != 2
591591
while i <= nr && j <= nc
592592
# find first pivot
@@ -693,7 +693,7 @@ function _rref_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range:
693693
return rnk, P
694694
end
695695

696-
function _rref_symp_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_range::UnitRange{Int})
696+
function _rref_symp_col_swap!(A::CTMatrixTypes, row_range::AbstractUnitRange{Int}, col_range::AbstractUnitRange{Int})
697697
# don't do anything to A if the range is empty, return rank 0 and missing permutation matrix
698698
isempty(row_range) && return 0, missing
699699
isempty(col_range) && return 0, missing
@@ -703,10 +703,10 @@ function _rref_symp_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_r
703703
nc_A = ncols(A)
704704

705705
rnk = 0
706-
i = row_range.start
707-
j = col_range.start
708-
nr = row_range.stop
709-
nc = col_range.stop
706+
i = first(row_range)
707+
j = first(col_range)
708+
nr = last(row_range)
709+
nc = last(col_range)
710710
if Int(order(base_ring(A))) != 2
711711
while i <= nr && j <= nc
712712
# find first pivot
@@ -750,7 +750,7 @@ function _rref_symp_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_r
750750
ind != i && swap_rows!(A, ind, i)
751751

752752
# eliminate
753-
for k = row_range.start:nr
753+
for k = first(row_range):nr
754754
if k != i
755755
# do a manual loop here to reduce allocations
756756
d = A[k, j]
@@ -802,7 +802,7 @@ function _rref_symp_col_swap!(A::CTMatrixTypes, row_range::UnitRange{Int}, col_r
802802
ind != i && swap_rows!(A, ind, i)
803803

804804
# eliminate
805-
for k = row_range.start:nr
805+
for k = first(row_range):nr
806806
if k != i
807807
if isone(A[k, j])
808808
# do a manual loop here to reduce allocations
@@ -824,10 +824,10 @@ _rref_symp_col_swap!(A::CTMatrixTypes) = _rref_symp_col_swap!(A, axes(A, 1), axe
824824
_rref_symp_col_swap(A::CTMatrixTypes) = (B = deepcopy(A); return _rref_symp_col_swap!(B);)
825825

826826
function _col_permutation!(X::Matrix{T}, A::Matrix{T}, p::AbstractVector{Int}) where T
827-
length(p) == size(A, 2) || error()
828-
size(X) == size(A) || error()
829-
for i in axes(X, 1)
830-
for j in axes(X, 2)
827+
length(p) == size(A, 2) || throw(ArgumentError("`p` should have length `size(A, 2)`."))
828+
size(X) == size(A) || throw(ArgumentError("`X` and `A` should have the same shape."))
829+
for j in axes(X, 2)
830+
for i in axes(X, 1)
831831
X[i, j] = A[i, p[j]]
832832
end
833833
end
@@ -836,13 +836,18 @@ end
836836

837837
function _col_permutation_symp!(X::Matrix{T}, A::Matrix{T}, p::AbstractVector{Int}) where T
838838
n = length(p)
839-
2n == size(A, 2) || error()
840-
size(X) == size(A) || error()
841-
for i in axes(X, 1)
842-
for j in axes(X, 2)
839+
2n == size(A, 2) || throw(ArgumentError("`p` should have length `size(A, 2)/2`."))
840+
size(X) == size(A) || throw(ArgumentError("`X` and `A` should have the same shape."))
841+
for j in 1:n
842+
for i in axes(X, 1)
843843
X[i, j] = A[i, p[mod1(j, n)]]
844844
end
845845
end
846+
for j in 1:n
847+
for i in axes(X, 1)
848+
X[i, j + n] = A[i, p[n] + n]
849+
end
850+
end
846851
return nothing
847852
end
848853

0 commit comments

Comments
 (0)