Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/DFTKAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using PrecompileTools
using LinearAlgebra
import DFTK: GPU, precompilation_workflow
using DFTK
import ForwardDiff
import ForwardDiff: Dual
Comment on lines +7 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also be removed, right ?


DFTK.synchronize_device(::GPU{<:AMDGPU.ROCArray}) = AMDGPU.synchronize()

Expand Down Expand Up @@ -39,5 +41,4 @@ if AMDGPU.functional()
end
end
end

end
2 changes: 1 addition & 1 deletion src/DispatchFunctional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function DftFunctionals.has_energy(func::LibxcFunctional)
0 in Libxc.supported_derivatives(Libxc.Functional(func.identifier))
end

function libxc_unfold_spin(data::Matrix, n_spin::Int)
function libxc_unfold_spin(data::AbstractMatrix, n_spin::Int)
n_p = size(data, 2)
if n_spin == 1
data # Only one spin component
Expand Down
6 changes: 3 additions & 3 deletions src/gpu/gpu_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ function LinearAlgebra.norm(A::Hermitian{T, <:AbstractGPUArray}) where {T}
sqrt(2upper_triangle - diago)
end


# Make sure that there is a CPU fallback for AbstractGPUArrays (e.g. for Duals)
for fun in (:potential_terms, :kernel_terms)
@eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT,
args...) where {AT <: AbstractGPUArray{Float64}}
args...) where {AT <: AbstractGPUArray}
# Fallback implementation for the GPU: Transfer to the CPU and run computation there
cpuify(::Nothing) = nothing
cpuify(x::AbstractArray) = Array(x)
$fun(fun, Array(ρ), cpuify.(args)...)
end
end
end
2 changes: 2 additions & 0 deletions src/response/chi0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ
# We then use the extra information we have from these additional bands,
# non-necessarily converged, to split the Sternheimer_solver with a Schur
# complement.
occupation = [to_cpu(oc) for oc in occupation]
(mask_occ, mask_extra) = occupied_empty_masks(occupation, occupation_threshold)

ψ_occ = [ψ[ik][:, maskk] for (ik, maskk) in enumerate(mask_occ)]
Expand Down Expand Up @@ -561,6 +562,7 @@ function construct_bandtol(Bandtol::Type, basis::PlaneWaveBasis, ψ, occupation:
Ω = basis.model.unit_cell_volume
Ng = prod(basis.fft_size)
Nk = length(basis.kpoints)
occupation = [to_cpu(oc) for oc in occupation]
mask_occ = occupied_empty_masks(occupation, occupation_threshold).mask_occ

# Including k-points the expression (3.11) in 2505.02319 becomes
Expand Down
12 changes: 7 additions & 5 deletions src/terms/local_nonlinearity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ struct TermLocalNonlinearity{TF} <: TermNonlinear
end
(L::LocalNonlinearity)(::AbstractBasis) = TermLocalNonlinearity(L.f)

# FD on the GPU, when T<:Dual causes all sorts of troubles, at least on AMD. TODO: also on NVIDIA?
# TODO: only transfer to CPU when T <: Dual ?, or only if ROCArray?
function ene_ops(term::TermLocalNonlinearity, basis::PlaneWaveBasis{T}, ψ, occupation;
ρ, kwargs...) where {T}
fp(ρ) = ForwardDiff.derivative(term.f, ρ)
E = sum(fρ -> convert_dual(T, fρ), term.f.(ρ)) * basis.dvol
potential = convert_dual.(T, fp.(ρ))
potential = to_device(basis.architecture, convert_dual.(T, fp.(to_cpu(ρ))))

# In the case of collinear spin, the potential is spin-dependent
ops = [RealSpaceMultiplication(basis, kpt, potential[:, :, :, kpt.spin])
Expand All @@ -22,16 +24,16 @@ function ene_ops(term::TermLocalNonlinearity, basis::PlaneWaveBasis{T}, ψ, occu
end


function compute_kernel(term::TermLocalNonlinearity, ::AbstractBasis{T}; ρ, kwargs...) where {T}
function compute_kernel(term::TermLocalNonlinearity, basis::AbstractBasis{T}; ρ, kwargs...) where {T}
fp(ρ) = ForwardDiff.derivative(term.f, ρ)
fpp(ρ) = ForwardDiff.derivative(fp, ρ)
Diagonal(vec(convert_dual.(T, fpp.(ρ))))
Diagonal(to_device(basis.architecture, vec(convert_dual.(T, fpp.(to_cpu(ρ))))))
end

function apply_kernel(term::TermLocalNonlinearity, ::AbstractBasis{T},
function apply_kernel(term::TermLocalNonlinearity, basis::AbstractBasis{T},
δρ::AbstractArray{Tδρ}; ρ, kwargs...) where {T, Tδρ}
S = promote_type(T, Tδρ)
fp(ρ) = ForwardDiff.derivative(term.f, ρ)
fpp(ρ) = ForwardDiff.derivative(fp, ρ)
convert_dual.(S, fpp.(ρ) .* δρ)
to_device(basis.architecture, convert_dual.(S, fpp.(to_cpu(ρ)) .* to_cpu(δρ)))
end
14 changes: 11 additions & 3 deletions src/terms/xc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function (xc::Xc)(basis::PlaneWaveBasis{T}) where {T}
# Strip duals from functional parameters if needed
params = parameters(fun)
if !isempty(params)
newparams = convert_dual.(T, params)
newparams = map(p -> convert_dual(T, p), params)
fun = change_parameters(fun, newparams; keep_identifier=true)
end
fun
Expand Down Expand Up @@ -427,7 +427,7 @@ function apply_kernel(term::TermXc, basis::PlaneWaveBasis{T}, δρ::AbstractArra

# If the XC functional is not supported for an architecture, terms is on the CPU
terms = kernel_terms(term.functionals, density)
δV = zeros(Tδρ, size(ρ)...) # [ix, iy, iz, iσ]
δV = zeros_like(δρ, Tδρ, size(ρ)...) # [ix, iy, iz, iσ]

Vρρ = to_device(basis.architecture, reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...))
@views for s = 1:n_spin, t = 1:n_spin # LDA term
Expand Down Expand Up @@ -529,11 +529,19 @@ _matify(data::AbstractArray) = reshape(data, size(data, 1), :)

for fun in (:potential_terms, :kernel_terms)
@eval begin
function DftFunctionals.$fun(xc::Functional, density::LibxcDensities)
function DftFunctionals.$fun(xc::DispatchFunctional, density::LibxcDensities)
$fun(xc, _matify(density.ρ_real), _matify(density.σ_real),
_matify(density.τ_real), _matify(density.Δρ_real))
end

# Ensure functionals from DftFunctionals are sent to the CPU, until DftFunctionals.jl is refactored
function DftFunctionals.$fun(fun::DftFunctionals.Functional, density::LibxcDensities)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird to me. Why is this needed on top of the above ? Should the types not somehow depend on a GPU type here ?

Copy link
Collaborator Author

@abussy abussy Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tricky bit is to avoid ambiguity with the internal definitions of DftFunctionals.jl.

Adding the following to src/gpu/gpu_array.jl leads to ambiguity, because it does not specialize on the type of functional (:lda, :gga, or :mgga):

for fun in (:potential_terms, :kernel_terms)
    @eval function DftFunctionals.$fun(fun::DftFunctionals.Functional, ρ::AT,
                                       args...) where {AT <: AbstractGPUArray}
        # Fallback implementation for the GPU: Transfer to the CPU and run computation there
        cpuify(::Nothing) = nothing
        cpuify(x::AbstractArray) = Array(x)
        $fun(fun, Array(ρ), cpuify.(args)...)
    end
end

Either I write the above for each functional type (lot of code duplication), or I parametrize it with a second loop over functional types, e.g.:

for fun in (:potential_terms, :kernel_terms), ftype in (:lda, :gga, :mgga)
    @eval function DftFunctionals.$fun(fun::DftFunctionals.Functional{$(QuoteNode(ftype))},
                                       ρ::AT, args...) where {AT <: AbstractGPUArray}
        # Fallback implementation for the GPU: Transfer to the CPU and run computation there
        cpuify(::Nothing) = nothing
        cpuify(x::AbstractArray) = Array(x)
        $fun(fun, Array(ρ), cpuify.(args)...)
    end
end

I don't like any of these alternative very much. I think the current solution carries a very clear message: anything DftFunctionals related goes to the CPU.

maticpuify(::Nothing) = nothing
maticpuify(x::AbstractArray) = reshape(Array(x), size(x, 1), :)
DftFunctionals.$fun(fun, maticpuify(density.ρ_real), maticpuify(density.σ_real),
maticpuify(density.τ_real), maticpuify(density.Δρ_real))
end

function DftFunctionals.$fun(xcs::Vector{Functional}, density::LibxcDensities)
isempty(xcs) && return NamedTuple()
result = $fun(xcs[1], density)
Expand Down
28 changes: 15 additions & 13 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ end
function LinearAlgebra.mul!(y::AbstractArray{<:Union{Complex{<:Dual}}},
p::AbstractFFTs.Plan,
x::AbstractArray{<:Union{Complex{<:Dual}}})
copyto!(y, p*x)
copyto!(y, _mul(p, x))
end
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg}
function _mul(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this feels strange and is surprising to me. Why did you need this ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this workaround, the GPU compiler throws an invalid LLVM IR error during stress calculations. I think there is confusion around which method of Base.:* to use, but I don't understand why.

# TODO do we want x::AbstractArray{<:Dual{T}} too?
xtil = p * ForwardDiff.value.(x)
dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n
Expand All @@ -46,6 +46,8 @@ function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}})
)
end
end
Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual}}) = _mul(p, x)
Base.:*(p::DummyInplace, x::AbstractArray{<:Union{Complex{<:Dual}}}) = copyto!(x, _mul(p.fft, x))

function build_fft_plans!(tmp::AbstractArray{Complex{T}}) where {T<:Dual}
opFFT = AbstractFFTs.plan_fft(tmp)
Expand Down Expand Up @@ -219,10 +221,9 @@ function construct_value(basis::PlaneWaveBasis{T}) where {T <: Dual}
end


@timing "self_consistent_field ForwardDiff" function self_consistent_field(
basis_dual::PlaneWaveBasis{T};
response=ResponseOptions(),
kwargs...) where {T <: Dual}
@timing "self_consistent_field ForwardDiff" function self_consistent_field(basis_dual::PlaneWaveBasis{<:Dual{Tg,V,N}};
response=ResponseOptions(),
kwargs...) where {Tg,V,N}
# Note: No guarantees on this interface yet.

# Primal pass
Expand All @@ -241,28 +242,29 @@ end
end
# Implicit differentiation
response.verbose && println("Solving response problem")
δresults = ntuple(ForwardDiff.npartials(T)) do α
δresults = ntuple(N) do α
δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual]
δtemperature = ForwardDiff.partials(basis_dual.model.temperature, α)
solve_ΩplusK_split(scfres, δHextψ; δtemperature,
tol=last(scfres.history_Δρ), response.verbose)
end

# Convert and combine
DT = Dual{ForwardDiff.tagtype(T)}
ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk...
map(ψk, δψk...) do ψnk, δψnk...
Complex(DT(real(ψnk), real.(δψnk)),
DT(imag(ψnk), imag.(δψnk)))
Complex(Dual{Tg}(real(ψnk), real.(δψnk)),
Dual{Tg}(imag(ψnk), imag.(δψnk)))
end
end
eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk...
map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)
map((εnk, δεnk...) -> Dual{Tg}(εnk, δεnk), εk, δεk...)
end
occupation = map(scfres.occupation, getfield.(δresults, :δoccupation)...) do occk, δocck...
map((occnk, δoccnk...) -> DT(occnk, δoccnk), occk, δocck...)
occk_cpu = to_cpu(occk)
to_device(basis_dual.architecture,
map((occnk, δoccnk...) -> Dual{Tg}(occnk, δoccnk), occk_cpu, δocck...))
end
εF = DT(scfres.εF, getfield.(δresults, :δεF)...)
εF = Dual{Tg}(scfres.εF, getfield.(δresults, :δεF)...)

# For strain, basis_dual contributes an explicit lattice contribution which
# is not contained in δresults, so we need to recompute ρ here
Expand Down
Loading
Loading