Skip to content

Commit 04b329b

Browse files
abussy authored and niklasschmitz committed
Use columnwise_dots where applicable (#1127)
1 parent 617e2d5 commit 04b329b

File tree

6 files changed

+11
-23
lines changed

6 files changed

+11
-23
lines changed

src/DFTK.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("common/quadrature.jl")
4444
include("common/hankel.jl")
4545
include("common/hydrogenic.jl")
4646
include("common/derivatives.jl")
47+
include("common/linalg.jl")
4748

4849
export PspHgh
4950
export PspUpf
@@ -140,7 +141,6 @@ export PreconditionerNone
140141
export lobpcg_hyper
141142
export diag_full
142143
export diagonalize_all_kblocks
143-
include("eigen/linalg.jl")
144144
include("eigen/preconditioners.jl")
145145
include("eigen/diag.jl")
146146

src/eigen/lobpcg_hyper_impl.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
# other eigenvectors (which is not the case in many - all ? - other
3434
# implementations)
3535

36-
# - Some functions are reimplemented in a GPU optimized way as part of
37-
# the DFTK CUDA Extension (ext/DFTKCUDAExt/lobpcg.jl).
36+
# - Some generic linear algebra functions are used in this implementation. They can be
37+
# found in src/common/linalg.jl. GPU-optimized versions of these functions are located
38+
# in src/gpu/linalg.jl.
3839

3940

4041
## TODO micro-optimization of buffer reuse

src/terms/kinetic.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,12 @@ end
4444
if isnothing(ψ) || isnothing(occupation)
4545
return (; E=T(Inf), ops)
4646
end
47-
occupation = [to_cpu(occk) for occk in occupation]
4847

4948
E = zero(T)
5049
for (ik, ψk) in enumerate(ψ)
51-
for iband = 1:size(ψk, 2)
52-
ψnk = @views ψk[:, iband]
53-
E += (basis.kweights[ik] * occupation[ik][iband]
54-
* real(dot(ψnk, Diagonal(term.kinetic_energies[ik]), ψnk)))
55-
end
50+
E += basis.kweights[ik] *
51+
sum(occupation[ik] .*
52+
real(vec(columnwise_dots(ψk, Diagonal(term.kinetic_energies[ik]), ψk))))
5653
end
5754
E = mpi_sum(E, basis.comm_kpts)
5855

src/terms/magnetic.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ function ene_ops(term::TermMagnetic, basis::PlaneWaveBasis{T}, ψ, occupation;
5252

5353
E = zero(T)
5454
for (ik, k) in enumerate(basis.kpoints)
55-
for iband = 1:size(ψ[1], 2)
56-
ψnk = @views ψ[ik][:, iband]
57-
# TODO optimize this
58-
E += basis.kweights[ik] * occupation[ik][iband] * real(dot(ψnk, ops[ik] * ψnk))
59-
end
55+
E += basis.kweights[ik] * sum(occupation[ik] .* vec(real(columnwise_dots(ψ[ik], ops[ik] * ψ[ik]))))
6056
end
6157
E = mpi_sum(E, basis.comm_kpts)
6258

src/terms/nonlocal.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ end
7070
# We compute the forces from the irreducible BZ; they are symmetrized later.
7171
G_plus_k_cart = to_cpu(Gplusk_vectors_cart(basis, kpt))
7272
G_plus_k = Gplusk_vectors(basis, kpt)
73-
occupationk = to_cpu(occupation[ik])
7473
form_factors = to_device(basis.architecture,
7574
build_projector_form_factors(element.psp, G_plus_k_cart))
7675

@@ -91,9 +90,7 @@ end
9190
map!(p -> -2π*im*p[α], twoπp, G_plus_k)
9291
dPdR .= twoπp .* P
9392
mul!(δHψk, P, C * (dPdR' * ψ[ik]))
94-
@views -sum(occupationk[iband] * basis.kweights[ik] *
95-
2real(dot(ψ[ik][:, iband], δHψk[:, iband]))
96-
for iband=1:size(ψ[ik], 2))
93+
-basis.kweights[ik]*sum(occupation[ik] .* 2vec(real(columnwise_dots(ψ[ik], δHψk))))
9794
end # α
9895
end # r
9996
end # kpt
@@ -311,11 +308,8 @@ function compute_dynmat_δH(::TermAtomicNonlocal, basis::PlaneWaveBasis{T}, ψ,
311308
δHψk_plus_q = derivative_wrt_αs(model.positions, α, idx) do positions_αs
312309
PDPψk(basis, positions_αs, psp_groups, kpt, kpt, ψ[ik])
313310
end
314-
-sum( 2occupation[ik][iband] * basis.kweights[ik]
315-
* dot(δψk_plus_q[:, iband], δHψk[:, iband])
316-
+ δoccupation[ik][iband] * basis.kweights[ik]
317-
* 2real(dot(ψk[:, iband], δHψk_plus_q[:, iband]))
318-
for iband=1:size(ψk, 2))
311+
-basis.kweights[ik] * sum(2occupation[ik] .* vec(columnwise_dots(δψk_plus_q, δHψk)) +
312+
δoccupation[ik] .* 2vec(real(columnwise_dots(ψk, δHψk_plus_q))))
319313
end
320314
end
321315
end

0 commit comments

Comments
 (0)