Skip to content

Commit 22ddb12

Browse files
committed
fix the bug for dsdx
1 parent 0d2a100 commit 22ddb12

File tree

3 files changed

+122
-16
lines changed

3 files changed

+122
-16
lines changed

pyxtal/lego/SO3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def compute_dpdr_5d(self, atoms):
272272
pair_ids = neigh_ids[self.neighbor_indices[:, 0] == i]
273273
if len(pair_ids) > 0:
274274
ctot = cs[pair_ids].sum(axis=0) #(n, l, m)
275+
dctot = dcs[pair_ids].sum(axis=0)
275276
# power spectrum P = c*c_conj
276277
# eq_3 (n, n', l) eliminate m
277278
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
@@ -287,12 +288,16 @@ def compute_dpdr_5d(self, atoms):
287288
# (N_ijs, n, n', l, 3)
288289
# dc * c_conj + c * dc_conj
289290
dP = np.einsum('ijkn, ljk->iljn', dcs[pair_id], np.conj(ctot))
290-
dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
291+
dP += np.einsum('ijkn, ljk->iljn', np.conj(dcs[pair_id]), ctot)
292+
#dP += np.conj(np.transpose(dP, axes=[1, 0, 2, 3]))
293+
#dP += np.einsum('ijkn, ljk->iljn', np.conj(dctot), cs[pair_id])
294+
#dP += np.einsum('ijkn, ljk->iljn', dctot, np.conj(cs[pair_id]))
295+
291296
dP = dP.real[self.tril_indices].flatten().reshape(self.ncoefs, 3)
292297
#print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)
293298

294299
dp_list[i, j, :, :, cell_id] += dP
295-
dp_list[i, i, :, :, cell_id] -= dP
300+
dp_list[i, i, :, :, 13] -= dP
296301

297302
return dp_list, p_list
298303

tests/test_SO3.py

Lines changed: 114 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,75 @@
66
from ase import Atoms
77
from ase.build import bulk, sort
88

9+
def calculate_S(atoms, P_ref):
10+
P = calculator.compute_p(atoms)
11+
S = np.sum((P - P_ref)**2)
12+
return S
13+
14+
15+
def numerical_dSdx(x, xtal, P_ref, eps=1e-4):
16+
if type(x) == list: x = np.array(x)
17+
18+
xtal.update_from_1d_rep(x)
19+
atoms = xtal.to_ase() #* 2
20+
S0 = calculate_S(atoms, P_ref)
21+
dSdx = np.zeros(len(x))
22+
for i in range(len(x)):
23+
x0 = x.copy()
24+
x0[i] += eps
25+
xtal.update_from_1d_rep(x0)
26+
atoms = xtal.to_ase() #* 2
27+
S1 = calculate_S(atoms, P_ref)
28+
29+
x0 = x.copy()
30+
x0[i] -= eps
31+
xtal.update_from_1d_rep(x0)
32+
atoms = xtal.to_ase() #* 2
33+
S2 = calculate_S(atoms, P_ref)
34+
35+
dSdx[i] = 0.5*(S1-S2)/eps
36+
return dSdx
37+
38+
39+
def calculate_dSdx_supercell(x, xtal, P_ref, eps=1e-4):
40+
41+
xtal.update_from_1d_rep(x)
42+
atoms = xtal.to_ase() #* 2
43+
44+
dPdr, P = calculator.compute_dpdr_5d(atoms)
45+
46+
# Compute dSdr [N, M] [N, N, M, 3, 27] => [N, 3, 27]
47+
dSdr = np.einsum("ik, ijklm -> jlm", 2*(P - P_ref), dPdr)
48+
49+
# Get supercell positions
50+
ref_pos = np.repeat(atoms.positions[:, :, np.newaxis], 27, axis=2)
51+
for cell in range(27):
52+
x1, y1, z1 = cell // 9 - 1, (cell // 3) % 3 - 1, cell % 3 - 1
53+
ref_pos[:, :, cell] += np.array([x1, y1, z1]) @ atoms.cell
54+
55+
# Compute drdx via numerical func
56+
drdx = np.zeros([len(atoms), 3, 27, len(x)])
57+
58+
xtal0 = xtal.copy()
59+
for i in range(len(x)):
60+
x0 = x.copy()
61+
x0[i] += eps
62+
xtal0.update_from_1d_rep(x0)
63+
atoms = xtal0.to_ase()
64+
65+
# Get supercell positions
66+
pos = np.repeat(atoms.positions[:, :, np.newaxis], 27, axis=2)
67+
for cell in range(27):
68+
x1, y1, z1 = cell // 9 - 1, (cell // 3) % 3 - 1, cell % 3 - 1
69+
pos[:, :, cell] += np.array([x1, y1, z1]) @ atoms.cell
70+
71+
drdx[:, :, :, i] += (pos - ref_pos)/eps
72+
73+
# [N, 3, 27] [N, 3, 27, H] => H
74+
dSdx = np.einsum("ijk, ijkl -> l", dSdr, drdx)
75+
return dSdx
76+
77+
978
def get_rotated_cluster(struc, angle=0, axis='x'):
1079
s_new = struc.copy()
1180
s_new.rotate(angle, axis)
@@ -28,15 +97,15 @@ def get_perturbed_xtal(struc, p0, p1, eps):
2897
p_struc.set_positions(pos)
2998
return p_struc
3099

31-
def get_dPdR_xtal(xtal, nmax, lmax, rc, eps):
32-
p0 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(xtal, derivative=True)
100+
def get_dPdR_xtal(xtal, eps):
101+
p0 = calculator.calculate(xtal, derivative=True)
33102
shp = p0['x'].shape
34103
array1 = p0['dxdr']
35104

36105
for j in range(shp[0]):
37106
for k in range(3):
38107
struc = get_perturbed_xtal(xtal, j, k, eps)
39-
p1 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
108+
p1 = calculator.calculate(struc)
40109
array2 = (p1['x'] - p0['x'])/eps
41110
#if np.linalg.norm(array2) > 1e-2: print(j, k, array2)
42111
if not np.allclose(array1[:, j, :, k], array2, atol=1e-4):
@@ -45,21 +114,29 @@ def get_dPdR_xtal(xtal, nmax, lmax, rc, eps):
45114

46115
# Descriptors Parameters
47116
eps = 1e-8
48-
rc = 2.80
117+
rc1 = 1.8
118+
rc2 = 3.5
49119
nmax, lmax = 2, 2
120+
calculator = SO3(nmax=nmax, lmax=lmax, rcut=rc1)
121+
calculator0 = SO3(nmax=nmax, lmax=lmax, rcut=rc2)
50122

51123
# NaCl cluster
52124
cluster = bulk('NaCl', crystalstructure='rocksalt', a=5.691694, cubic=True)
53125
cluster = sort(cluster, tags=[0, 4, 1, 5, 2, 6, 3, 7])
54126
cluster.set_pbc((0,0,0))
55127
cluster = get_rotated_cluster(cluster, angle=0.1) # Must rotate
56128

129+
xtal = pyxtal()
130+
xtal.from_prototype('graphite')
131+
atoms = xtal.to_ase()
132+
P_ref = calculator.compute_p(atoms)[0]
133+
57134
# Diamond
58135
class TestCluster(unittest.TestCase):
59136
struc = get_rotated_cluster(cluster)
60-
p0 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc, derivative=True)
137+
p0 = calculator0.calculate(struc, derivative=True)
61138
struc = get_rotated_cluster(cluster, 10, 'x')
62-
p1 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
139+
p1 = calculator0.calculate(struc)
63140

64141
def test_SO3_rotation_variance(self):
65142
array1 = self.p0['x']
@@ -73,7 +150,7 @@ def test_dPdR_vs_numerical(self):
73150
for j in range(shp[0]):
74151
for k in range(3):
75152
struc = get_perturbed_cluster(cluster, j, k, eps)
76-
p2 = SO3(nmax=nmax, lmax=lmax, rcut=rc).calculate(struc)
153+
p2 = calculator0.calculate(struc)
77154
array2 = (p2['x'] - self.p0['x'])/eps
78155
assert(np.allclose(array1[:,j,:,k], array2, atol=1e-3))
79156

@@ -82,30 +159,53 @@ class TestXtal(unittest.TestCase):
82159
def test_dPdR_diamond(self):
83160
c = pyxtal()
84161
c.from_prototype('diamond')
85-
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
162+
get_dPdR_xtal(c.to_ase(), eps)
86163

87164
def test_dPdR_graphite(self):
88165
c = pyxtal()
89166
c.from_prototype('graphite')
90-
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
167+
get_dPdR_xtal(c.to_ase(), eps)
91168

92169
def test_dPdR_random(self):
93170
x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
94171
c = pyxtal()
95172
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
96-
get_dPdR_xtal(c.to_ase(), nmax, lmax, rc, eps)
173+
get_dPdR_xtal(c.to_ase(), eps)
97174

98175
def test_dPdR_random_P(self):
99176
x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
100177
c = pyxtal()
101178
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
102179
atoms = c.to_ase()
103-
f = SO3(nmax=nmax, lmax=lmax, rcut=rc)
104-
p0 = f.compute_p(atoms)
105-
_, p1 = f.compute_dpdr(atoms)
106-
_, p2 = f.compute_dpdr_5d(atoms)
180+
p0 = calculator.compute_p(atoms)
181+
_, p1 = calculator.compute_dpdr(atoms)
182+
_, p2 = calculator.compute_dpdr_5d(atoms)
107183
assert(np.allclose(p0, p1, atol=1e-3))
108184
assert(np.allclose(p0, p2, atol=1e-3))
109185

186+
class TestSimilarity(unittest.TestCase):
187+
188+
def test_sim_diamond(self):
189+
x = [3.0]
190+
c = pyxtal()
191+
c.from_spg_wps_rep(227, ['8a'], x, ['C'])
192+
atoms = c.to_ase()
193+
x = c.get_1d_rep_x()
194+
dSdx1 = numerical_dSdx(x, c, P_ref)
195+
dSdx2 = calculate_dSdx_supercell(x, c, P_ref)
196+
#print(dSdx1, dSdx2)
197+
assert(np.allclose(dSdx1, dSdx2, rtol=1e-1, atol=1e+1))
198+
199+
def test_dPdR_random(self):
200+
#x = [ 7.952, 2.606, 0.592, 0.926, 0.608, 0.307]
201+
x = [9.55, 2.60, 0.48, 0.88, 0.76, 0.36]
202+
c = pyxtal()
203+
c.from_spg_wps_rep(179, ['6a', '6a', '6a', '6a'], x)
204+
atoms = c.to_ase()
205+
dSdx1 = numerical_dSdx(x, c, P_ref)
206+
dSdx2 = calculate_dSdx_supercell(x, c, P_ref)
207+
#print(dSdx1, dSdx2)
208+
assert(np.allclose(dSdx1, dSdx2, rtol=1e-1, atol=1e+1))
209+
110210
if __name__ == "__main__":
111211
unittest.main()

tests/test_lego.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_opt_xtal(self):
4343
xtal = pyxtal()
4444
xtal.from_spg_wps_rep(spg, wps, x, ['C']*len(wps))
4545
xtal, sim, _ = builder1.optimize_xtal(xtal, add_db=False)
46+
#print(xtal.get_1d_rep_x())
4647
assert sim < 1e-2
4748

4849
def test_opt_xtal2(self):

0 commit comments

Comments
 (0)