-
Notifications
You must be signed in to change notification settings - Fork 97
Port stress/response calculations to the GPU #1187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,3 +23,14 @@ for fun in (:potential_terms, :kernel_terms) | |
| $fun(fun, Array(ρ), cpuify.(args)...) | ||
| end | ||
| end | ||
|
|
||
| # Make sure that computations done by DftFunctionals.jl are done on the CPU (until refactoring) | ||
| for fun in (:potential_terms, :kernel_terms) | ||
| @eval function DftFunctionals.$fun(fun::DispatchFunctional, ρ::AT, | ||
| args...) where {AT <: AbstractGPUArray} | ||
| # Fallback implementation for the GPU: Transfer to the CPU and run computation there | ||
| cpuify(::Nothing) = nothing | ||
| cpuify(x::AbstractArray) = Array(x) | ||
| $fun(fun, Array(ρ), cpuify.(args)...) | ||
| end | ||
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,7 +42,7 @@ function (xc::Xc)(basis::PlaneWaveBasis{T}) where {T} | |
| # Strip duals from functional parameters if needed | ||
| params = parameters(fun) | ||
| if !isempty(params) | ||
| newparams = convert_dual.(T, params) | ||
| newparams = map(p -> convert_dual(T, p), params) | ||
| fun = change_parameters(fun, newparams; keep_identifier=true) | ||
| end | ||
| fun | ||
|
|
@@ -427,7 +427,7 @@ function apply_kernel(term::TermXc, basis::PlaneWaveBasis{T}, δρ::AbstractArra | |
|
|
||
| # If the XC functional is not supported for an architecture, terms is on the CPU | ||
| terms = kernel_terms(term.functionals, density) | ||
| δV = zeros(Tδρ, size(ρ)...) # [ix, iy, iz, iσ] | ||
| δV = zeros_like(δρ, Tδρ, size(ρ)...) # [ix, iy, iz, iσ] | ||
|
|
||
| Vρρ = to_device(basis.architecture, reshape(terms.Vρρ, n_spin, n_spin, basis.fft_size...)) | ||
| @views for s = 1:n_spin, t = 1:n_spin # LDA term | ||
|
|
@@ -529,11 +529,19 @@ _matify(data::AbstractArray) = reshape(data, size(data, 1), :) | |
|
|
||
| for fun in (:potential_terms, :kernel_terms) | ||
| @eval begin | ||
| function DftFunctionals.$fun(xc::Functional, density::LibxcDensities) | ||
| function DftFunctionals.$fun(xc::DispatchFunctional, density::LibxcDensities) | ||
| $fun(xc, _matify(density.ρ_real), _matify(density.σ_real), | ||
| _matify(density.τ_real), _matify(density.Δρ_real)) | ||
| end | ||
|
|
||
| # Ensure functionals from DftFunctionals are sent to the CPU, until DftFunctionals.jl is refactored | ||
| function DftFunctionals.$fun(fun::DftFunctionals.Functional, density::LibxcDensities) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks weird to me. Why is this needed on top of the above? Should the types not somehow depend on a GPU type here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The tricky bit is to avoid ambiguity with the internal definitions of DftFunctionals.jl. Adding the following to […] Either I write the above for each functional type (a lot of code duplication), or I parametrize it with a second loop over functional types, e.g.: […] I don't like either of these alternatives very much. I think the current solution carries a very clear message: anything DftFunctionals-related goes to the CPU. |
||
| maticpuify(::Nothing) = nothing | ||
| maticpuify(x::AbstractArray) = reshape(Array(x), size(x, 1), :) | ||
| DftFunctionals.$fun(fun, maticpuify(density.ρ_real), maticpuify(density.σ_real), | ||
| maticpuify(density.τ_real), maticpuify(density.Δρ_real)) | ||
| end | ||
|
|
||
| function DftFunctionals.$fun(xcs::Vector{Functional}, density::LibxcDensities) | ||
| isempty(xcs) && return NamedTuple() | ||
| result = $fun(xcs[1], density) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -31,9 +31,9 @@ end | |||||||||||||
| function LinearAlgebra.mul!(y::AbstractArray{<:Union{Complex{<:Dual}}}, | ||||||||||||||
| p::AbstractFFTs.Plan, | ||||||||||||||
| x::AbstractArray{<:Union{Complex{<:Dual}}}) | ||||||||||||||
| copyto!(y, p*x) | ||||||||||||||
| copyto!(y, _mul(p, x)) | ||||||||||||||
| end | ||||||||||||||
| function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg} | ||||||||||||||
| function _mul(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg} | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Again this feels strange and is surprising to me. Why did you need this?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Without this workaround, the GPU compiler throws an […] |
||||||||||||||
| # TODO do we want x::AbstractArray{<:Dual{T}} too? | ||||||||||||||
| xtil = p * ForwardDiff.value.(x) | ||||||||||||||
| dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n | ||||||||||||||
|
|
@@ -46,6 +46,8 @@ function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) | |||||||||||||
| ) | ||||||||||||||
| end | ||||||||||||||
| end | ||||||||||||||
| Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual}}) = _mul(p, x) | ||||||||||||||
| Base.:*(p::DummyInplace, x::AbstractArray{<:Union{Complex{<:Dual}}}) = copyto!(x, _mul(p.fft, x)) | ||||||||||||||
|
|
||||||||||||||
| function build_fft_plans!(tmp::AbstractArray{Complex{T}}) where {T<:Dual} | ||||||||||||||
| opFFT = AbstractFFTs.plan_fft(tmp) | ||||||||||||||
|
|
@@ -219,10 +221,9 @@ function construct_value(basis::PlaneWaveBasis{T}) where {T <: Dual} | |||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @timing "self_consistent_field ForwardDiff" function self_consistent_field( | ||||||||||||||
| basis_dual::PlaneWaveBasis{T}; | ||||||||||||||
| response=ResponseOptions(), | ||||||||||||||
| kwargs...) where {T <: Dual} | ||||||||||||||
| @timing "self_consistent_field ForwardDiff" function self_consistent_field(basis_dual::PlaneWaveBasis{<:Dual{T,V,N}}; | ||||||||||||||
| response=ResponseOptions(), | ||||||||||||||
| kwargs...) where {T,V,N} | ||||||||||||||
| # Note: No guarantees on this interface yet. | ||||||||||||||
|
|
||||||||||||||
| # Primal pass | ||||||||||||||
|
|
@@ -241,28 +242,29 @@ end | |||||||||||||
| end | ||||||||||||||
| # Implicit differentiation | ||||||||||||||
| response.verbose && println("Solving response problem") | ||||||||||||||
| δresults = ntuple(ForwardDiff.npartials(T)) do α | ||||||||||||||
| δresults = ntuple(N) do α | ||||||||||||||
| δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual] | ||||||||||||||
| δtemperature = ForwardDiff.partials(basis_dual.model.temperature, α) | ||||||||||||||
| solve_ΩplusK_split(scfres, δHextψ; δtemperature, | ||||||||||||||
| tol=last(scfres.history_Δρ), response.verbose) | ||||||||||||||
| end | ||||||||||||||
|
|
||||||||||||||
| # Convert and combine | ||||||||||||||
| DT = Dual{ForwardDiff.tagtype(T)} | ||||||||||||||
| ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk... | ||||||||||||||
| map(ψk, δψk...) do ψnk, δψnk... | ||||||||||||||
| Complex(DT(real(ψnk), real.(δψnk)), | ||||||||||||||
| DT(imag(ψnk), imag.(δψnk))) | ||||||||||||||
| Complex(Dual{T}(real(ψnk), real.(δψnk)), | ||||||||||||||
| Dual{T}(imag(ψnk), imag.(δψnk))) | ||||||||||||||
|
||||||||||||||
| function LinearAlgebra.norm(x::SVector{S,<:Dual{Tg,T,N}}) where {S,Tg,T,N} | |
| x_value = ForwardDiff.value.(x) | |
| y = norm(x_value) | |
| dy = ntuple(j->real(dot(x_value, ForwardDiff.partials.(x,j))) * pinv(y), N) | |
| Dual{Tg}(y, dy) | |
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah no, I mean that Dual{T} is fine. I misread it the first time. What is weird to me is why this is necessary, given that the GPU compiler isn't operating directly on this method? (Or is it?)
Looking at @code_typed might be quite insightful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I didn't see @Technici4n's last message when I wrote the above.
The GPU compiler requires this to be able to compile. I am not sure what the root cause is, but type instability is a likely candidate. Generally, the GPU compiler is rather bad at type inference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can also be removed, right?