Skip to content

Commit 99ac78a

Browse files
authored
Ensure contiguous occupations (#1189)
1 parent e2672b6 commit 99ac78a

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

src/occupation.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ end
3030
function compute_occupation(basis::PlaneWaveBasis{T}, eigenvalues::AbstractVector, εF::Number;
3131
temperature=basis.model.temperature,
3232
smearing=basis.model.smearing) where {T}
33+
# Check that eigenvalues are increasing monotonically, for contiguous occupations
34+
if !all(all(diff(εk) .≥ -eps(T)) for εk in eigenvalues)
35+
error("Eigenvalues should be monotonically increasing.")
36+
end
37+
3338
# This is needed to get the right behaviour for special floating-point types
3439
# such as intervals.
3540
inverse_temperature = iszero(temperature) ? T(Inf) : 1/temperature
@@ -207,3 +212,18 @@ function check_full_occupation(basis::PlaneWaveBasis, occupation)
207212
all(occ_k .== filled_occ) || error("Only full occupation is supported, but $occ_k has partial occupation.")
208213
end
209214
end
215+
216+
"""
217+
Return ranges of occupied elements based on a given occupation threshold
218+
"""
219+
function occupied_empty_masks(occupation, occupation_threshold)
220+
n_occ = map(occupation) do occ
221+
n = count(occ_i -> abs(occ_i) > occupation_threshold, occ)
222+
# Check that all occupied elements are contiguous, otherwise range is wrong
223+
@assert all(occ[1:n] .> occupation_threshold)
224+
n
225+
end
226+
mask_occ = [1:n_occ[ik] for ik in 1:length(occupation)]
227+
mask_empty = [(n_occ[ik] + 1):length(occupation[ik]) for ik in 1:length(occupation)]
228+
(; mask_occ, mask_empty)
229+
end

src/orbitals.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using Random # Used to have a generic API for CPU and GPU computations alike: s
55
# threshold is a parameter to distinguish between states we want to keep and the
66
# others when using temperature. It is set to 0.0 by default, to treat with insulators.
77
function select_occupied_orbitals(basis, ψ, occupation; threshold=0.0)
8-
N = [something(findlast(x -> x > threshold, occk), 0) for occk in occupation]
9-
selected_ψ = [@view ψk[:, 1:N[ik]] for (ik, ψk) in enumerate(ψ)]
10-
selected_occ = [ occk[1:N[ik]] for (ik, occk) in enumerate(occupation)]
8+
mask_occ = occupied_empty_masks(occupation, threshold).mask_occ
9+
selected_ψ = [@view ψk[:, mask_occ[ik]] for (ik, ψk) in enumerate(ψ)]
10+
selected_occ = [ occk[mask_occ[ik]] for (ik, occk) in enumerate(occupation)]
1111

1212
# if we have an insulator, sanity check that the orbitals we kept are the
1313
# occupied ones

src/response/chi0.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,7 @@ to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ
409409
# We then use the extra information we have from these additional bands,
410410
# non-necessarily converged, to split the Sternheimer_solver with a Schur
411411
# complement.
412-
occ_thresh = occupation_threshold
413-
mask_occ = map(occk -> findall(occnk -> abs(occnk) occ_thresh, occk), occupation)
414-
mask_extra = map(occk -> findall(occnk -> abs(occnk) < occ_thresh, occk), occupation)
412+
(mask_occ, mask_extra) = occupied_empty_masks(occupation, occupation_threshold)
415413

416414
ψ_occ = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ)]
417415
ψ_extra = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_extra)]
@@ -563,7 +561,7 @@ function construct_bandtol(Bandtol::Type, basis::PlaneWaveBasis, ψ, occupation:
563561
Ω = basis.model.unit_cell_volume
564562
Ng = prod(basis.fft_size)
565563
Nk = length(basis.kpoints)
566-
mask_occ = map(ok -> findall(onk -> abs(onk) occupation_threshold, ok), occupation)
564+
mask_occ = occupied_empty_masks(occupation, occupation_threshold).mask_occ
567565

568566
# Including k-points the expression (3.11) in 2505.02319 becomes
569567
# with Φk = (ψ_{1,k} … ψ_{n,k})_k (Concatenation of all occupied orbitals for this k)

src/supercell.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ function cell_to_supercell(scfres::NamedTuple)
112112
basis_supercell = cell_to_supercell(basis)
113113
ψ_supercell = [cell_to_supercell(ψ, basis, basis_supercell)]
114114
eigs_supercell = [reduce(vcat, scfres_unfold.eigenvalues)]
115+
perms = [sortperm(eigs_supercell[ik]) for ik = 1:length(eigs_supercell)]
116+
ψ_supercell = [ψ_supercell[ik][:, perms[ik]] for ik = 1:length(ψ_supercell)]
117+
eigs_supercell = [eigs_supercell[ik][perms[ik]] for ik = 1:length(eigs_supercell)]
115118
occ_supercell = compute_occupation(basis_supercell, eigs_supercell, scfres.εF).occupation
116119
ρ_supercell = compute_density(basis_supercell, ψ_supercell, occ_supercell;
117120
scfres.occupation_threshold)

0 commit comments

Comments
 (0)