Skip to content

Commit c4bb6d3

Browse files
committed
improve loop in RIS
1 parent 4eca25b commit c4bb6d3

File tree

1 file changed

+3
-42
lines changed

1 file changed

+3
-42
lines changed

src/Quantum/weight_dist.jl

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -981,9 +981,10 @@ function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int,
981981
uppers = [curr_u_bound for _ in 1:num_thrds]
982982
founds = [found for _ in 1:num_thrds]
983983
thread_load = Int(floor(max_iters / num_thrds))
984+
remaining = max_iters - thread_load * num_thrds
984985
Threads.@threads for t in 1:num_thrds
985986
log_test = zeros(Int, size(operators_to_reduce, 1), size(check_against, 2))
986-
for _ in 1:thread_load
987+
for _ in 1:(thread_load + (t <= remaining ? 1 : 0))
987988
if flag[]
988989
perm = shuffle(1:n)
989990
perm2 = [perm; perm .+ n]
@@ -1016,45 +1017,5 @@ function _RIS_bound_loop!(operators_to_reduce, check_against, curr_l_bound::Int,
10161017
end
10171018
end
10181019

1019-
# finish up the final couple
1020-
remaining = max_iters - thread_load * num_thrds
1021-
if !iszero(remaining)
1022-
upper_temp = n
1023-
found_temp::Vector{Int}
1024-
flag = true
1025-
log_test = zeros(Int, size(check_against, 2))
1026-
for _ in 1:remaining
1027-
if flag
1028-
perm = shuffle(1:n)
1029-
perm2 = [perm; perm .+ n]
1030-
perm_ops = operators_to_reduce[:, perm2]
1031-
_rref_no_col_swap_binary!(perm_ops)
1032-
ops = perm_ops[:, invperm(perm2)]
1033-
LinearAlgebra.mul!(log_test, ops, check_against)
1034-
for i in axes(log_test, 1)
1035-
# then ops[i, :] is a logical
1036-
if any(isodd, log_test[i, :])
1037-
w = 0
1038-
@inbounds for j in 1:n
1039-
(isodd(ops[i, j]) || isodd(ops[i, j + n])) && (w += 1;)
1040-
end
1041-
1042-
if upper_temp > w
1043-
upper_temp = w
1044-
found_temp .= ops[i, :]
1045-
verbose && println("Adjusting (thread's local) upper bound: $w")
1046-
if curr_l_bound == w
1047-
verbose && println("Found a logical that matched the lower bound of $curr_l_bound")
1048-
flag = false
1049-
break
1050-
end
1051-
end
1052-
end
1053-
end
1054-
end
1055-
end
1056-
return [uppers; upper_temp], [founds; found_temp]
1057-
else
1058-
return uppers, founds
1059-
end
1020+
return uppers, founds
10601021
end

0 commit comments

Comments
 (0)