Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/occupation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ end
function compute_occupation(basis::PlaneWaveBasis{T}, eigenvalues::AbstractVector, εF::Number;
temperature=basis.model.temperature,
smearing=basis.model.smearing) where {T}
# Check that eigenvalues are increasing monotonically, for contiguous occupations
if !all(all(diff(εk) .≥ -eps(T)) for εk in eigenvalues)
error("Eigenvalues should be monotonically increasing.")
end

# This is needed to get the right behaviour for special floating-point types
# such as intervals.
inverse_temperature = iszero(temperature) ? T(Inf) : 1/temperature
Expand Down Expand Up @@ -207,3 +212,18 @@ function check_full_occupation(basis::PlaneWaveBasis, occupation)
all(occ_k .== filled_occ) || error("Only full occupation is supported, but $occ_k has partial occupation.")
end
end

"""
Return ranges of occupied elements based on a given occupation threshold
"""
function occupied_empty_masks(occupation, occupation_threshold)
n_occ = map(occupation) do occ
n = count(occ_i -> abs(occ_i) > occupation_threshold, occ)
# Check that all occupied elements are contiguous, otherwise range is wrong
@assert all(occ[1:n] .> occupation_threshold)
n
end
mask_occ = [1:n_occ[ik] for ik in 1:length(occupation)]
mask_empty = [(n_occ[ik] + 1):length(occupation[ik]) for ik in 1:length(occupation)]
(; mask_occ, mask_empty)
end
6 changes: 3 additions & 3 deletions src/orbitals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ using Random # Used to have a generic API for CPU and GPU computations alike: s
# threshold is a parameter to distinguish between states we want to keep and the
# others when using temperature. It is set to 0.0 by default, to treat with insulators.
function select_occupied_orbitals(basis, ψ, occupation; threshold=0.0)
N = [something(findlast(x -> x > threshold, occk), 0) for occk in occupation]
selected_ψ = [@view ψk[:, 1:N[ik]] for (ik, ψk) in enumerate(ψ)]
selected_occ = [ occk[1:N[ik]] for (ik, occk) in enumerate(occupation)]
mask_occ = occupied_empty_masks(occupation, threshold).mask_occ
selected_ψ = [@view ψk[:, mask_occ[ik]] for (ik, ψk) in enumerate(ψ)]
selected_occ = [ occk[mask_occ[ik]] for (ik, occk) in enumerate(occupation)]

# if we have an insulator, sanity check that the orbitals we kept are the
# occupied ones
Expand Down
6 changes: 2 additions & 4 deletions src/response/chi0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,7 @@ to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ
# We then use the extra information we have from these additional bands,
# non-necessarily converged, to split the Sternheimer_solver with a Schur
# complement.
occ_thresh = occupation_threshold
mask_occ = map(occk -> findall(occnk -> abs(occnk) ≥ occ_thresh, occk), occupation)
mask_extra = map(occk -> findall(occnk -> abs(occnk) < occ_thresh, occk), occupation)
(mask_occ, mask_extra) = occupied_empty_masks(occupation, occupation_threshold)

ψ_occ = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ)]
ψ_extra = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_extra)]
Expand Down Expand Up @@ -563,7 +561,7 @@ function construct_bandtol(Bandtol::Type, basis::PlaneWaveBasis, ψ, occupation:
Ω = basis.model.unit_cell_volume
Ng = prod(basis.fft_size)
Nk = length(basis.kpoints)
mask_occ = map(ok -> findall(onk -> abs(onk) ≥ occupation_threshold, ok), occupation)
mask_occ = occupied_empty_masks(occupation, occupation_threshold).mask_occ

# Including k-points the expression (3.11) in 2505.02319 becomes
# with Φk = (ψ_{1,k} … ψ_{n,k})_k (Concatenation of all occupied orbitals for this k)
Expand Down
3 changes: 3 additions & 0 deletions src/supercell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ function cell_to_supercell(scfres::NamedTuple)
basis_supercell = cell_to_supercell(basis)
ψ_supercell = [cell_to_supercell(ψ, basis, basis_supercell)]
eigs_supercell = [reduce(vcat, scfres_unfold.eigenvalues)]
perms = [sortperm(eigs_supercell[ik]) for ik = 1:length(eigs_supercell)]
ψ_supercell = [ψ_supercell[ik][:, perms[ik]] for ik = 1:length(ψ_supercell)]
eigs_supercell = [eigs_supercell[ik][perms[ik]] for ik = 1:length(eigs_supercell)]
occ_supercell = compute_occupation(basis_supercell, eigs_supercell, scfres.εF).occupation
ρ_supercell = compute_density(basis_supercell, ψ_supercell, occ_supercell;
scfres.occupation_threshold)
Expand Down
Loading