Skip to content

Commit 331e873

Browse files
Fix symmetries & symmetry-breaking with ForwardDiff (#1082)
--------- Co-authored-by: Bruno Ploumhans <[email protected]>
1 parent b5832f5 commit 331e873

File tree

4 files changed

+241
-23
lines changed

4 files changed

+241
-23
lines changed

src/SymOp.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,49 @@ function Base.:*(op1::SymOp, op2::SymOp)
6666
end
6767
Base.inv(op::SymOp) = SymOp(inv(op.W), -op.W\op.w)
6868

69+
is_approx_in(symop, group; kwargs...) = any(s -> isapprox(s, symop; kwargs...), group)
6970
function check_group(symops::Vector; kwargs...)
70-
is_approx_in_symops(s1) = any(s -> isapprox(s, s1; kwargs...), symops)
71+
is_approx_in_symops(s1) = is_approx_in(s1, symops; kwargs...)
7172
is_approx_in_symops(one(SymOp)) || error("check_group: no identity element")
7273
for s in symops
7374
if !is_approx_in_symops(inv(s))
7475
error("check_group: symop $s with inverse $(inv(s)) is not in the group")
7576
end
7677
for s2 in symops
77-
if !is_approx_in_symops(s*s2) || !is_approx_in_symops(s2*s)
78-
error("check_group: product is not stable")
78+
if !is_approx_in_symops(s*s2)
79+
error("check_group: product is not stable: $(s*s2) is not in the group")
7980
end
8081
end
8182
end
8283
symops
8384
end
85+
86+
function complete_symop_group(symops; maxiter=10, kwargs...)
87+
completed_group = Vector(symops)
88+
89+
function add_to_group(to_add, s1)
90+
if !is_approx_in(s1, completed_group; kwargs...) && !is_approx_in(s1, to_add; kwargs...)
91+
push!(to_add, s1)
92+
end
93+
end
94+
95+
for it = 1:maxiter
96+
if it == maxiter
97+
error("Could not complete group in $maxiter iterations")
98+
end
99+
to_add = []
100+
# Identity always needs to be there!
101+
add_to_group(to_add, one(SymOp))
102+
for s in completed_group
103+
add_to_group(to_add, inv(s))
104+
for t in completed_group
105+
add_to_group(to_add, s*t)
106+
end
107+
end
108+
if isempty(to_add)
109+
return completed_group
110+
end
111+
append!(completed_group, to_add)
112+
end
113+
DFTK.check_group(completed_group) # returns the completed group
114+
end

src/symmetry.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -364,31 +364,40 @@ function symmetrize_stresses(basis::PlaneWaveBasis, stresses)
364364
symmetrize_stresses(basis.model, stresses; basis.symmetries)
365365
end
366366

367+
"""
368+
Find the symmetry preimage of `position`, returning the corresponding index in `positions_group`.
369+
"""
370+
function find_symmetry_preimage(positions_group, position, symop;
371+
tol_symmetry=SYMMETRY_TOLERANCE)
372+
# see (A.27) of https://arxiv.org/pdf/0906.2569.pdf
373+
# (but careful that our symmetries are r -> Wr+w, not R(r+f))
374+
other_at = symop.W \ (position - symop.w)
375+
376+
# Find the index of the atom to which idx is mapped to by the symmetry operation.
377+
# To avoid issues due to numerical noise we compute the deviations from being
378+
# an integer shift (thus equivalent by translational symmetry) for all atoms in
379+
# the group and pick the smallest one.
380+
smallest_deviation, i_other_at = findmin(positions_group) do at
381+
δat = at - other_at
382+
maximum(abs, δat - round.(δat))
383+
end
384+
# Note, that without a fudging factor this occasionally fails:
385+
@assert smallest_deviation < 10tol_symmetry
386+
387+
i_other_at
388+
end
389+
367390
function symmetrize_forces(positions::AbstractVector, atom_groups::AbstractVector, forces;
368391
symmetries, tol_symmetry=SYMMETRY_TOLERANCE)
369392
symmetrized_forces = zero(forces)
370393
for group in atom_groups, symop in symmetries
371394
positions_group = positions[group]
372-
W, w = symop.W, symop.w
373395
for (idx, position) in enumerate(positions_group)
374-
# see (A.27) of https://arxiv.org/pdf/0906.2569.pdf
375-
# (but careful that our symmetries are r -> Wr+w, not R(r+f))
376-
other_at = W \ (position - w)
377-
378-
# Find the index of the atom to which idx is mapped to by the symmetry operation.
379-
# To avoid issues due to numerical noise we compute the deviations from being
380-
# an integer shift (thus equivalent by translational symmetry) for all atoms in
381-
# the group and pick the smallest one.
382-
smallest_deviation, i_other_at = findmin(positions_group) do at
383-
δat = at - other_at
384-
maximum(abs, δat - round.(δat))
385-
end
386-
# Note, that without a fudging factor this occasionally fails:
387-
@assert smallest_deviation < 10tol_symmetry
396+
i_other_at = find_symmetry_preimage(positions_group, position, symop; tol_symmetry)
388397

389398
# (A.27) is in Cartesian coordinates, and since Wcart is orthogonal,
390399
# Fsymcart = Wcart * Fcart <=> Fsymred = inv(Wred') Fred
391-
symmetrized_forces[group[idx]] += inv(W') * forces[group[i_other_at]]
400+
symmetrized_forces[group[idx]] += inv(symop.W') * forces[group[i_other_at]]
392401
end
393402
end
394403
symmetrized_forces / length(symmetries)

src/workarounds/forwarddiff_rules.jl

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,76 @@ end
7171

7272
next_working_fft_size(::Type{<:Dual}, size::Int) = size
7373

74-
# determine symmetry operations only from primal lattice values
74+
# determine symmetry operations only from primal values, then filter out symmetries broken by dual part
7575
function symmetry_operations(lattice::AbstractMatrix{<:Dual},
76-
atoms, positions, magnetic_moments=[]; kwargs...)
76+
atoms, positions, magnetic_moments=[];
77+
tol_symmetry=SYMMETRY_TOLERANCE, kwargs...)
7778
positions_value = [ForwardDiff.value.(pos) for pos in positions]
78-
symmetry_operations(ForwardDiff.value.(lattice), atoms, positions_value,
79-
magnetic_moments; kwargs...)
79+
symmetries = symmetry_operations(ForwardDiff.value.(lattice), atoms,
80+
positions_value, magnetic_moments;
81+
tol_symmetry, kwargs...)
82+
remove_dual_broken_symmetries(lattice, atoms, positions, symmetries; tol_symmetry)
83+
end
84+
85+
function remove_dual_broken_symmetries(lattice, atoms, positions,
86+
symmetries; tol_symmetry=SYMMETRY_TOLERANCE)
87+
filter(symmetries) do symmetry
88+
!is_symmetry_broken_by_dual(lattice, atoms, positions, symmetry; tol_symmetry)
89+
end
90+
end
91+
92+
"""
93+
Return `true` if a symmetry that holds for the primal part is broken by
94+
a perturbation in the lattice or in the positions, `false` otherwise.
95+
"""
96+
function is_symmetry_broken_by_dual(lattice, atoms, positions, symmetry::SymOp; tol_symmetry)
97+
# For any lattice atom at position x, W*x + w should be in the lattice.
98+
# In cartesian coordinates, with a perturbed lattice A = A₀ + εA₁,
99+
# this means that for any atom position xcart in the unit cell and any 3 integers u,
100+
# there should be an atom at position ycart and 3 integers v such that:
101+
# Wcart * (xcart + A*u) + wcart = ycart + A*v
102+
# where
103+
# Wcart = A₀ * W * A₀⁻¹; note that W is an integer matrix
104+
# wcart = A₀ * w.
105+
#
106+
# In relative coordinates this gives:
107+
# A⁻¹ * Wcart * (A*x + A*u) + A⁻¹ * wcart = y + v (*)
108+
#
109+
# The strategy is then to check that:
110+
# 1. A⁻¹ * Wcart * A is still an integer matrix (i.e. no dual part),
111+
# such that any change in u is easily compensated for in v.
112+
# 2. The primal component of (*), i.e. with ε=0, is already known to hold.
113+
# Since v does not have a dual component, it is enough to check that
114+
# the dual part of the following is 0:
115+
# A⁻¹ * Wcart * A*x + A⁻¹ * wcart - y
116+
117+
lattice_primal = ForwardDiff.value.(lattice)
118+
W = (compute_inverse_lattice(lattice) * lattice_primal
119+
* symmetry.W * compute_inverse_lattice(lattice_primal) * lattice)
120+
w = compute_inverse_lattice(lattice) * lattice_primal * symmetry.w
121+
122+
is_dual_nonzero(x::AbstractArray) = any(x) do xi
123+
maximum(abs, ForwardDiff.partials(xi)) >= tol_symmetry
124+
end
125+
# Check 1.
126+
if is_dual_nonzero(W)
127+
return true
128+
end
129+
130+
atom_groups = [findall(Ref(pot) .== atoms) for pot in Set(atoms)]
131+
for group in atom_groups
132+
positions_group = positions[group]
133+
for position in positions_group
134+
i_other_at = find_symmetry_preimage(positions_group, position, symmetry; tol_symmetry)
135+
136+
# Check 2. with x = positions_group[i_other_at] and y = position
137+
if is_dual_nonzero(positions_group[i_other_at] + inv(W) * (w - position))
138+
return true
139+
end
140+
end
141+
end
142+
143+
false
80144
end
81145

82146
function _is_well_conditioned(A::AbstractArray{<:Dual}; kwargs...)

test/forwarddiff.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,117 @@ end
260260
derivative_fd = ForwardDiff.derivative(compute_force, 0.0)
261261
@test norm(derivative_ε - derivative_fd) < 1e-4
262262
end
263+
264+
@testitem "Symmetries broken by perturbation are filtered out" tags=[:dont_test_mpi] begin
265+
using DFTK
266+
using ForwardDiff
267+
using LinearAlgebra
268+
269+
lattice = [2. 0. 0.; 0. 1. 0.; 0. 0. 1.]
270+
positions = [[0., 0., 0.], [0.5, 0., 0.]]
271+
gauss = ElementGaussian(1.0, 0.5)
272+
atoms = [gauss, gauss]
273+
atom_groups = [findall(Ref(pot) .== atoms) for pot in Set(atoms)]
274+
275+
# Select some "interesting" subset of the symmetries
276+
# Rotation in the yz plane by 90 degrees
277+
rotyz = SymOp([1 0 0; 0 0 1; 0 -1 0], [0., 0., 0.])
278+
mirroryz = rotyz * rotyz
279+
# Mirror y
280+
mirrory = SymOp([1 0 0; 0 -1 0; 0 0 1], [0., 0., 0.])
281+
# Translation by 0.5 in the x direction
282+
transx = SymOp(diagm([1, 1, 1]), [0.5, 0., 0.])
283+
284+
# Generate the full group
285+
symmetries_full = DFTK.complete_symop_group([rotyz, mirrory, transx])
286+
@test length(symmetries_full) == 16
287+
288+
DFTK._check_symmetries(symmetries_full, lattice, atom_groups, positions)
289+
290+
function check_symmetries(filtered_symmetries, expected_symmetries)
291+
expected_symmetries = DFTK.complete_symop_group(expected_symmetries)
292+
293+
@test length(filtered_symmetries) == length(expected_symmetries)
294+
for s in expected_symmetries
295+
@test DFTK.is_approx_in(s, filtered_symmetries)
296+
end
297+
for s in filtered_symmetries
298+
@test DFTK.is_approx_in(s, expected_symmetries)
299+
end
300+
end
301+
302+
# Instantiate Dual to test with perturbations
303+
ε = ForwardDiff.Dual{ForwardDiff.Tag{Nothing, Float64}}(0.0, 1.0)
304+
305+
@testset "Atom movement" begin
306+
# Moving the second atom should break the transx symmetry, but not the others
307+
positions_modified = [[0., 0., 0.], [0.5 + ε, 0., 0.]]
308+
symmetries_filtered = DFTK.remove_dual_broken_symmetries(lattice, atoms, positions_modified, symmetries_full)
309+
310+
@test length(symmetries_filtered) == 8
311+
check_symmetries(symmetries_filtered, [rotyz, mirrory])
312+
end
313+
314+
@testset "Lattice strain" begin
315+
# Straining the lattice along z should break the rotyz symmetry, but not the others
316+
# In particular it should not break mirroryz which is normally generated by rotyz.
317+
lattice_modified = diagm([2., 1., 1. + ε])
318+
symmetries_filtered = DFTK.remove_dual_broken_symmetries(lattice_modified, atoms, positions, symmetries_full)
319+
320+
@test length(symmetries_filtered) == 8
321+
check_symmetries(symmetries_filtered, [mirrory, transx, mirroryz])
322+
end
323+
324+
@testset "Atom movement + lattice strain" begin
325+
# Only the mirrory and mirroryz symmetries should be left
326+
positions_modified = [[0., 0., 0.], [0.5 + ε, 0., 0.]]
327+
lattice_modified = diagm([2., 1., 1. + ε])
328+
symmetries_filtered = DFTK.remove_dual_broken_symmetries(lattice_modified, atoms, positions_modified, symmetries_full)
329+
330+
@test length(symmetries_filtered) == 4
331+
check_symmetries(symmetries_filtered, [mirrory, mirroryz])
332+
end
333+
end
334+
335+
@testitem "Symmetry-breaking perturbation using ForwardDiff" #=
336+
=# tags=[:dont_test_mpi] setup=[TestCases] begin
337+
using DFTK
338+
using ForwardDiff
339+
using LinearAlgebra
340+
aluminium = TestCases.aluminium
341+
342+
@testset for perturbation in (:lattice, :positions)
343+
function run_scf(ε)
344+
lattice = if perturbation == :lattice
345+
v = ε * [0., 0., 0., 0., 0., 1.]
346+
DFTK.voigt_strain_to_full(v) * aluminium.lattice
347+
else
348+
# Lattice has to be a dual for position perturbations to work
349+
aluminium.lattice * one(typeof(ε))
350+
end
351+
pos = if perturbation == :lattice
352+
aluminium.positions
353+
else
354+
map(enumerate(aluminium.positions)) do (i, x)
355+
i == 1 ? x + ε * [1., 0, 0] : x
356+
end
357+
end
358+
359+
model = model_DFT(lattice, aluminium.atoms, pos;
360+
functionals=LDA(), temperature=1e-2,
361+
smearing=Smearing.Gaussian())
362+
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2])
363+
self_consistent_field(basis; tol=1e-10)
364+
end
365+
366+
δρ = ForwardDiff.derivative-> run_scf(ε).ρ, 0.)
367+
368+
h = 1e-6
369+
scfres1 = run_scf(-h)
370+
scfres2 = run_scf(+h)
371+
δρ_finitediff = (scfres2.ρ - scfres1.ρ) / 2h
372+
373+
rtol = 1e-4
374+
@test norm(δρ - δρ_finitediff, 1) < rtol * norm(δρ, 1)
375+
end
376+
end

0 commit comments

Comments
 (0)