diff --git a/src/occupation.jl b/src/occupation.jl index 14d1aca5f9..301dfa7904 100644 --- a/src/occupation.jl +++ b/src/occupation.jl @@ -30,6 +30,11 @@ end function compute_occupation(basis::PlaneWaveBasis{T}, eigenvalues::AbstractVector, εF::Number; temperature=basis.model.temperature, smearing=basis.model.smearing) where {T} + # Check that eigenvalues are increasing monotonically, for contiguous occupations + if !all(all(diff(εk) .≥ -eps(T)) for εk in eigenvalues) + error("Eigenvalues should be monotonically increasing.") + end + # This is needed to get the right behaviour for special floating-point types # such as intervals. inverse_temperature = iszero(temperature) ? T(Inf) : 1/temperature @@ -207,3 +212,18 @@ function check_full_occupation(basis::PlaneWaveBasis, occupation) all(occ_k .== filled_occ) || error("Only full occupation is supported, but $occ_k has partial occupation.") end end + +""" +Return ranges of occupied elements based on a given occupation threshold +""" +function occupied_empty_masks(occupation, occupation_threshold) + n_occ = map(occupation) do occ + n = count(occ_i -> abs(occ_i) > occupation_threshold, occ) + # Check that all occupied elements are contiguous, otherwise range is wrong + @assert all(occ[1:n] .> occupation_threshold) + n + end + mask_occ = [1:n_occ[ik] for ik in 1:length(occupation)] + mask_empty = [(n_occ[ik] + 1):length(occupation[ik]) for ik in 1:length(occupation)] + (; mask_occ, mask_empty) +end \ No newline at end of file diff --git a/src/orbitals.jl b/src/orbitals.jl index 3456492cfd..a472fc6dc2 100644 --- a/src/orbitals.jl +++ b/src/orbitals.jl @@ -5,9 +5,9 @@ using Random # Used to have a generic API for CPU and GPU computations alike: s # threshold is a parameter to distinguish between states we want to keep and the # others when using temperature. It is set to 0.0 by default, to treat with insulators. function select_occupied_orbitals(basis, ψ, occupation; threshold=0.0) - N = [something(findlast(x -> x > threshold, occk), 0) for occk in occupation] - selected_ψ = [@view ψk[:, 1:N[ik]] for (ik, ψk) in enumerate(ψ)] - selected_occ = [ occk[1:N[ik]] for (ik, occk) in enumerate(occupation)] + mask_occ = occupied_empty_masks(occupation, threshold).mask_occ + selected_ψ = [@view ψk[:, mask_occ[ik]] for (ik, ψk) in enumerate(ψ)] + selected_occ = [ occk[mask_occ[ik]] for (ik, occk) in enumerate(occupation)] # if we have an insulator, sanity check that the orbitals we kept are the # occupied ones diff --git a/src/response/chi0.jl b/src/response/chi0.jl index 67f51bbbf6..d13eaceeda 100644 --- a/src/response/chi0.jl +++ b/src/response/chi0.jl @@ -409,9 +409,7 @@ to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ # We then use the extra information we have from these additional bands, # non-necessarily converged, to split the Sternheimer_solver with a Schur # complement. - occ_thresh = occupation_threshold - mask_occ = map(occk -> findall(occnk -> abs(occnk) ≥ occ_thresh, occk), occupation) - mask_extra = map(occk -> findall(occnk -> abs(occnk) < occ_thresh, occk), occupation) + (mask_occ, mask_extra) = occupied_empty_masks(occupation, occupation_threshold) ψ_occ = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ)] ψ_extra = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_extra)] @@ -563,7 +561,7 @@ function construct_bandtol(Bandtol::Type, basis::PlaneWaveBasis, ψ, occupation: Ω = basis.model.unit_cell_volume Ng = prod(basis.fft_size) Nk = length(basis.kpoints) - mask_occ = map(ok -> findall(onk -> abs(onk) ≥ occupation_threshold, ok), occupation) + mask_occ = occupied_empty_masks(occupation, occupation_threshold).mask_occ # Including k-points the expression (3.11) in 2505.02319 becomes # with Φk = (ψ_{1,k} … ψ_{n,k})_k (Concatenation of all occupied orbitals for this k) diff --git a/src/supercell.jl b/src/supercell.jl index fbc3d82d30..25d8477f89 100644 --- a/src/supercell.jl +++ b/src/supercell.jl @@ -112,6 +112,9 @@ function cell_to_supercell(scfres::NamedTuple) basis_supercell = cell_to_supercell(basis) ψ_supercell = [cell_to_supercell(ψ, basis, basis_supercell)] eigs_supercell = [reduce(vcat, scfres_unfold.eigenvalues)] + perms = [sortperm(eigs_supercell[ik]) for ik = 1:length(eigs_supercell)] + ψ_supercell = [ψ_supercell[ik][:, perms[ik]] for ik = 1:length(ψ_supercell)] + eigs_supercell = [eigs_supercell[ik][perms[ik]] for ik = 1:length(eigs_supercell)] occ_supercell = compute_occupation(basis_supercell, eigs_supercell, scfres.εF).occupation ρ_supercell = compute_density(basis_supercell, ψ_supercell, occ_supercell; scfres.occupation_threshold)