Skip to content

Commit c501bfd

Browse files
committed
Optimise inverse Fourier transform
1 parent 96bf0ea commit c501bfd

File tree

4 files changed

+46
-32
lines changed

4 files changed

+46
-32
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2525
Spglib = "f761d5c5-86db-4880-b97f-9680a7cccfb5"
2626
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2727
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
28+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
2829
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2930
UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
3031

@@ -53,6 +54,7 @@ SafeTestsets = "0.1"
5354
Spglib = "1.0.1"
5455
StaticArrays = "1.9"
5556
StructArrays = "0.7"
57+
Tullio = "0.3.8"
5658
Unitful = "1.25"
5759
UnitfulAtomic = "1"
5860
julia = "1.12"

src/Quoll.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ArgCheck
66
using LinearAlgebra
77
using StaticArrays
88
using OffsetArrays
9+
using Tullio
910
using Dictionaries
1011
using AutoHashEquals
1112
using AxisKeys

src/conversions/canonical/bsparse.jl

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,48 @@ end
3434

3535
# TODO: write another method but with shifts (or alternatively write the method with shifts
3636
# but force the compiler to remove shifts for matching SH convention)
37+
# TODO: could use out_keydata here instead
3738
function fourier_transform_data!(out_operator::DenseOperator, in_operator::BSparseOperator, phases_k)
3839
in_keydata = get_keydata(in_operator)
3940
in_sparsity = get_sparsity(in_operator)
4041
out_data = get_data(out_operator)
4142
out_basisset = get_basisset(out_operator)
42-
out_float = get_float(out_operator)
4343

4444
ilocal2iglobal = get_ilocal2iglobal(in_sparsity)
45-
atom2offset = get_atom2offset(out_basisset)
45+
atom2basis = get_atom2basis(out_basisset)
46+
47+
# atom2offset = get_atom2offset(out_basisset)
48+
# out_float = get_float(out_operator)
4649

4750
for ((iat, jat), in_block) in pairs(in_keydata)
4851

4952
phases_kij = phases_k[ilocal2iglobal[(iat, jat)]]
50-
atom2offset_i = atom2offset[iat]
51-
atom2offset_j = atom2offset[jat]
5253

53-
for jb in axes(in_block, 2)
54-
jb_dense = jb + atom2offset_j
55-
for ib in axes(in_block, 1)
56-
ib_dense = ib + atom2offset_i
54+
in_matrix = reshape(in_block, size(in_block, 1) * size(in_block, 2), size(in_block, 3))
55+
out_data[atom2basis[iat], atom2basis[jat]] = in_matrix * phases_kij
5756

58-
tmp = zero(out_float)
59-
@inbounds for iR in axes(in_block, 3)
60-
tmp += in_block[ib, jb, iR] * phases_kij[iR]
61-
end
57+
# atom2offset_i = atom2offset[iat]
58+
# atom2offset_j = atom2offset[jat]
6259

63-
out_data[ib_dense, jb_dense] = tmp
64-
end
65-
end
60+
# for jb in axes(in_block, 2)
61+
# jb_dense = jb + atom2offset_j
62+
# for ib in axes(in_block, 1)
63+
# ib_dense = ib + atom2offset_i
64+
65+
# tmp = zero(out_float)
66+
# @inbounds for iR in eachindex(phases_kij)
67+
# # tmp += in_block[ib, jb, iR] * phases_kij[iR]
68+
# tmp = muladd(in_block[ib, jb, iR], phases_kij[iR], tmp)
69+
# end
70+
71+
# out_data[ib_dense, jb_dense] = tmp
72+
# end
73+
# end
6674

6775
# @tullio tmp[ib,jb] := in_block[ib,jb,R] * phases_kij[R]
6876
# out_data[atom2basis[iat], atom2basis[jat]] = tmp
6977
end
7078

71-
# TODO: I could use Hermitian array instead, probably directly in build stage
7279
if in_sparsity.hermitian
7380
for jb_dense in 1:size(out_data, 1)
7481
for ib_dense in 1:jb_dense
@@ -80,30 +87,35 @@ end
8087

8188
# Accumulates only a single k-point's weighted contribution into out_operator, because in_operator contains only a single k-point
8289
function inv_fourier_transform_data!(out_operator::BSparseOperator, in_operator::DenseOperator, phases_k, weight)
83-
out_basisset = get_basisset(out_operator)
8490
out_keydata = get_keydata(out_operator)
8591
out_sparsity = get_sparsity(out_operator)
92+
in_basisset = get_basisset(out_operator)
8693
in_data = get_data(in_operator)
8794

8895
ilocal2iglobal = get_ilocal2iglobal(out_sparsity)
89-
atom2offset = get_atom2offset(out_basisset)
96+
atom2basis = get_atom2basis(in_basisset)
97+
98+
# atom2offset = get_atom2offset(in_basisset)
9099

91100
for ((iat, jat), out_block) in pairs(out_keydata)
92101

93-
# Assuming phases is a vector here (i.e. for a single k-point)
94-
inv_phases_kij = conj(phases_k[ilocal2iglobal[(iat, jat)]])
95-
atom2offset_i = atom2offset[iat]
96-
atom2offset_j = atom2offset[jat]
102+
inv_phases_kij = conj(@view(phases_k[ilocal2iglobal[(iat, jat)]]))
103+
in_data_ij = @view(in_data[atom2basis[iat], atom2basis[jat]])
104+
105+
@tullio out_block[ib,jb,iR] += weight * real(in_data_ij[ib,jb] * inv_phases_kij[iR])
106+
107+
# atom2offset_i = atom2offset[iat]
108+
# atom2offset_j = atom2offset[jat]
97109

98-
for iR in axes(out_block, 3)
99-
for jb in axes(out_block, 2)
100-
jb_dense = jb + atom2offset_j
101-
for ib in axes(out_block, 1)
102-
ib_dense = ib + atom2offset_i
103-
out_block[ib, jb, iR] += weight * real(in_data[ib_dense, jb_dense] * inv_phases_kij[iR])
104-
end
105-
end
106-
end
110+
# for iR in axes(out_block, 3)
111+
# for jb in axes(out_block, 2)
112+
# jb_dense = jb + atom2offset_j
113+
# for ib in axes(out_block, 1)
114+
# ib_dense = ib + atom2offset_i
115+
# out_block[ib, jb, iR] += weight * real(in_data[ib_dense, jb_dense] * inv_phases_kij[iR])
116+
# end
117+
# end
118+
# end
107119
end
108120

109121
end

src/projections/common.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function perform_core_projection(operators, projected_basis, kgrid::KGrid, my_ik
9393
return v_operators
9494
end
9595

96-
# TODO: if not views, then sometimes can result in redundant allocations
9796
function core_valence_partition(operator::DenseOperator, core_mask::BitVector, valence_mask::BitVector)
9897
O = get_data(operator)
9998
O₁₁ = view(O, core_mask, core_mask)

0 commit comments

Comments
 (0)