-
Notifications
You must be signed in to change notification settings - Fork 97
Port stress/response calculations to the GPU #1187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,7 +42,7 @@ function (xc::Xc)(basis::PlaneWaveBasis{T}) where {T} | |
| # Strip duals from functional parameters if needed | ||
| params = parameters(fun) | ||
| if !isempty(params) | ||
| newparams = convert_dual.(T, params) | ||
| newparams = map(p -> convert_dual(T, p), params) | ||
| fun = change_parameters(fun, newparams; keep_identifier=true) | ||
| end | ||
| fun | ||
|
|
@@ -427,7 +427,7 @@ function apply_kernel(term::TermXc, basis::PlaneWaveBasis{T}, δρ::AbstractArra | |
|
|
||
| # If the XC functional is not supported for an architecture, terms is on the CPU | ||
| terms = kernel_terms(term.functionals, density) | ||
| δV = zeros(Tδρ, size(ρ)...) # [ix, iy, iz, iσ] | ||
| δV = zeros_like(δρ, Tδρ, size(ρ)...) # [ix, iy, iz, iσ] | ||
|
|
||
| Vρρ = to_device(basis.architecture, reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...)) | ||
| @views for s = 1:n_spin, t = 1:n_spin # LDA term | ||
|
|
@@ -529,11 +529,19 @@ _matify(data::AbstractArray) = reshape(data, size(data, 1), :) | |
|
|
||
| for fun in (:potential_terms, :kernel_terms) | ||
| @eval begin | ||
| function DftFunctionals.$fun(xc::Functional, density::LibxcDensities) | ||
| function DftFunctionals.$fun(xc::DispatchFunctional, density::LibxcDensities) | ||
| $fun(xc, _matify(density.ρ_real), _matify(density.σ_real), | ||
| _matify(density.τ_real), _matify(density.Δρ_real)) | ||
| end | ||
|
|
||
| # Ensure functionals from DftFunctionals are sent to the CPU, until DftFunctionals.jl is refactored | ||
| function DftFunctionals.$fun(fun::DftFunctionals.Functional, density::LibxcDensities) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks weird to me. Why is this needed on top of the above? Should the types not somehow depend on a GPU type here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The tricky bit is to avoid ambiguity with the internal definitions of DftFunctionals.jl. Adding the following to Either I write the above for each functional type (lot of code duplication), or I parametrize it with a second loop over functional types, e.g.: I don't like any of these alternatives very much. I think the current solution carries a very clear message: anything DftFunctionals related goes to the CPU. |
||
| maticpuify(::Nothing) = nothing | ||
| maticpuify(x::AbstractArray) = reshape(Array(x), size(x, 1), :) | ||
| DftFunctionals.$fun(fun, maticpuify(density.ρ_real), maticpuify(density.σ_real), | ||
| maticpuify(density.τ_real), maticpuify(density.Δρ_real)) | ||
| end | ||
|
|
||
| function DftFunctionals.$fun(xcs::Vector{Functional}, density::LibxcDensities) | ||
| isempty(xcs) && return NamedTuple() | ||
| result = $fun(xcs[1], density) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,9 +31,9 @@ end | |
| function LinearAlgebra.mul!(y::AbstractArray{<:Union{Complex{<:Dual}}}, | ||
| p::AbstractFFTs.Plan, | ||
| x::AbstractArray{<:Union{Complex{<:Dual}}}) | ||
| copyto!(y, p*x) | ||
| copyto!(y, _mul(p, x)) | ||
| end | ||
| function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg} | ||
| function _mul(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Again, this feels strange and is surprising to me. Why did you need this?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Without this workaround, the GPU compiler throws an |
||
| # TODO do we want x::AbstractArray{<:Dual{T}} too? | ||
| xtil = p * ForwardDiff.value.(x) | ||
| dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n | ||
|
|
@@ -46,6 +46,8 @@ function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) | |
| ) | ||
| end | ||
| end | ||
| Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual}}) = _mul(p, x) | ||
| Base.:*(p::DummyInplace, x::AbstractArray{<:Union{Complex{<:Dual}}}) = copyto!(x, _mul(p.fft, x)) | ||
|
|
||
| function build_fft_plans!(tmp::AbstractArray{Complex{T}}) where {T<:Dual} | ||
| opFFT = AbstractFFTs.plan_fft(tmp) | ||
|
|
@@ -219,10 +221,9 @@ function construct_value(basis::PlaneWaveBasis{T}) where {T <: Dual} | |
| end | ||
|
|
||
|
|
||
| @timing "self_consistent_field ForwardDiff" function self_consistent_field( | ||
| basis_dual::PlaneWaveBasis{T}; | ||
| response=ResponseOptions(), | ||
| kwargs...) where {T <: Dual} | ||
| @timing "self_consistent_field ForwardDiff" function self_consistent_field(basis_dual::PlaneWaveBasis{<:Dual{Tg,V,N}}; | ||
| response=ResponseOptions(), | ||
| kwargs...) where {Tg,V,N} | ||
| # Note: No guarantees on this interface yet. | ||
|
|
||
| # Primal pass | ||
|
|
@@ -241,28 +242,29 @@ end | |
| end | ||
| # Implicit differentiation | ||
| response.verbose && println("Solving response problem") | ||
| δresults = ntuple(ForwardDiff.npartials(T)) do α | ||
| δresults = ntuple(N) do α | ||
| δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual] | ||
| δtemperature = ForwardDiff.partials(basis_dual.model.temperature, α) | ||
| solve_ΩplusK_split(scfres, δHextψ; δtemperature, | ||
| tol=last(scfres.history_Δρ), response.verbose) | ||
| end | ||
|
|
||
| # Convert and combine | ||
| DT = Dual{ForwardDiff.tagtype(T)} | ||
| ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk... | ||
| map(ψk, δψk...) do ψnk, δψnk... | ||
| Complex(DT(real(ψnk), real.(δψnk)), | ||
| DT(imag(ψnk), imag.(δψnk))) | ||
| Complex(Dual{Tg}(real(ψnk), real.(δψnk)), | ||
| Dual{Tg}(imag(ψnk), imag.(δψnk))) | ||
| end | ||
| end | ||
| eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk... | ||
| map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...) | ||
| map((εnk, δεnk...) -> Dual{Tg}(εnk, δεnk), εk, δεk...) | ||
| end | ||
| occupation = map(scfres.occupation, getfield.(δresults, :δoccupation)...) do occk, δocck... | ||
| map((occnk, δoccnk...) -> DT(occnk, δoccnk), occk, δocck...) | ||
| occk_cpu = to_cpu(occk) | ||
| to_device(basis_dual.architecture, | ||
| map((occnk, δoccnk...) -> Dual{Tg}(occnk, δoccnk), occk_cpu, δocck...)) | ||
| end | ||
| εF = DT(scfres.εF, getfield.(δresults, :δεF)...) | ||
| εF = Dual{Tg}(scfres.εF, getfield.(δresults, :δεF)...) | ||
|
|
||
| # For strain, basis_dual contributes an explicit lattice contribution which | ||
| # is not contained in δresults, so we need to recompute ρ here | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can also be removed, right?