Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ include("workarounds/forwarddiff_rules.jl")
# Optimized generic GPU functions and GPU workarounds
include("gpu/linalg.jl")
include("gpu/gpu_arrays.jl")
include("gpu/local.jl")

# Precompilation block with a basic workflow

Expand Down
2 changes: 1 addition & 1 deletion src/common/spherical_bessels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ with `SpecialFunctions.sphericalbesselj`. Specialized for integer ``0 ≤ l ≤
l == 3 && return (sin(x) * (15 - 6x^2) + cos(x) * (x^3 - 15x)) / x^4
l == 4 && return (sin(x) * (105 - 45x^2 + x^4) + cos(x) * (10x^3 - 105x)) / x^5
l == 5 && return (sin(x) * (945 - 420x^2 + 15x^4) + cos(x) * (-945x + 105x^3 - x^5)) / x^6
error("The case l = $l is not implemented")
throw(BoundsError()) # specific l not implemented
end
34 changes: 34 additions & 0 deletions src/gpu/local.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# GPU workarounds for atomic grid integrations of the AtomicLocal term. GPU kernels can
# only take isbits data as input, and Upf elements are far from being isbits. Since only
# a limited number of operations ever become rate limiting, we simply rewrite
# those in a GPU optimized way here.
function atomic_local_inner_loop!(form_factors_cpu, norm_indices, igroup,
element::ElementPsp{<:PspUpf}, arch::GPU{AT}) where {AT}

x = @view element.psp.rgrid[1:3]
uniform_grid = (x[2] - x[1]) ≈ (x[3] - x[2]) ? true : false

rgrid = to_device(arch, @view element.psp.rgrid[1:element.psp.ircut])
vloc = to_device(arch, @view element.psp.vloc[1:element.psp.ircut])
ps = to_device(arch, collect(keys(norm_indices)))
Zion = element.psp.Zion

ints = map(ps) do p
T = eltype(p)
method = uniform_grid ? simpson_uniform : simpson_nonuniform
if p == 0
zero(T)
else
# GPU compilation error if branching done within generic simpson() function
I = method(rgrid) do i, r
r * (r * vloc[i] - -Zion * erf(r)) * sphericalbesselj_fast(0, p * r)
end
4T(π) * (I + -Zion / p^2 * exp(-p^2 / T(4)))
end
end

ints_cpu = to_cpu(ints)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is weird. The only reason this is on the CPU is because we needed it for something. Now essentially you put form_factors on the CPU only to put it on the GPU once this function call is over, right ? Can this not directly be a GPU array in the GPU version of the code ?

for (p, I) in zip(keys(norm_indices), ints_cpu)
form_factors_cpu[norm_indices[p], igroup] = I
end
end
14 changes: 9 additions & 5 deletions src/terms/local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,21 @@ function atomic_local_form_factors(basis::PlaneWaveBasis{T}; q=zero(Vec3{T})) wh
end

form_factors_cpu = zeros(T, length(norm_indices), length(basis.model.atom_groups))
for(p, ifnorm) in norm_indices
for (igroup, group) in enumerate(basis.model.atom_groups)
element = basis.model.atoms[first(group)]
form_factors_cpu[ifnorm, igroup] = local_potential_fourier(element, p)
end
for (igroup, group) in enumerate(basis.model.atom_groups)
element = basis.model.atoms[first(group)]
atomic_local_inner_loop!(form_factors_cpu, norm_indices, igroup, element, basis.architecture)
end

form_factors = to_device(basis.architecture, form_factors_cpu)
iG2ifnorm = to_device(basis.architecture, iG2ifnorm_cpu)
(; form_factors, iG2ifnorm)
end
function atomic_local_inner_loop!(form_factors_cpu, norm_indices, igroup,
element::Element, arch::AbstractArchitecture)
for (p, ifnorm) in norm_indices
form_factors_cpu[ifnorm, igroup] = local_potential_fourier(element, p)
end
end

## Atomic local potential
struct TermAtomicLocal{AT} <: TermLocalPotential
Expand Down
Loading