Skip to content

Commit ebdf998

Browse files
committed
update get_tabular_rep and SO3
1 parent 62cac32 commit ebdf998

File tree

4 files changed

+214
-118
lines changed

4 files changed

+214
-118
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ jobs:
5959
- name: Test_supergroup
6060
run: pytest tests/test_supergroup.py
6161

62+
- name: Test_so3
63+
run: pytest tests/test_SO3.py
64+
6265
#- name: Test_xrd
6366
# run: pytest tests/test_xrd.py
6467

pyxtal/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3625,7 +3625,7 @@ def get_xtal_string(self, dicts=None, header=None):
36253625

36263626
def get_tabular_representations(
36273627
self,
3628-
N_max=30,
3628+
N_max=200,
36293629
N_wp=8,
36303630
normalize=False,
36313631
perturb=False,
@@ -3650,16 +3650,24 @@ def get_tabular_representations(
36503650
a list of equivalent 1D tabular representations
36513651
"""
36523652
reps = []
3653+
36533654
# To prevent the explosion of big multiplicity number
3654-
min_wp = 20 if len(self.atom_sites) <= 5 else int(
3655-
np.power(1000000.0, 1 / len(self.atom_sites)))
3655+
if len(self.atom_sites) <= 1:
3656+
min_wp = 192
3657+
elif len(self.atom_sites) <= 2:
3658+
min_wp = 100
3659+
elif len(self.atom_sites) <= 4:
3660+
min_wp = 20
3661+
else:
3662+
min_wp = int(np.power(100000.0, 1 / len(self.atom_sites)))
3663+
36563664
sites_mul = [range(min([min_wp, site.wp.multiplicity]))
36573665
for site in self.atom_sites]
36583666
ids = list(itertools.product(*sites_mul))
36593667
if len(ids) > N_max:
36603668
ids = self.random_state.choice(ids, N_max)
36613669

3662-
print(f"N_reps {len(ids)} from ", self.get_xtal_string())
3670+
print(f"N_reps {len(ids)}/{N_max}: ", self.get_xtal_string())
36633671
for sites_id in ids:
36643672
rep = self.get_tabular_representation(
36653673
sites_id, normalize, N_wp, perturb,

pyxtal/lego/SO3.py

Lines changed: 105 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0,
2828
self.cutoff_function = 'cosine'
2929
self.weight_on = weight_on
3030
self.neighborcalc = neighborlist
31-
#return
31+
self.ncoefs = self.nmax*(self.nmax+1)//2*(self.lmax+1)
32+
self.tril_indices = np.tril_indices(self.nmax, k=0)
33+
self.ls = np.arange(self.lmax+1)
34+
self.norm = np.sqrt(2*np.sqrt(2)*np.pi/np.sqrt(2*self.ls+1))
35+
self.keys = ['keys', '_nmax', '_lmax', '_rcut', '_alpha',
36+
'_cutoff_function', 'weight_on', 'neighborcalc',
37+
'ncoefs', 'ls', 'norm', 'tril_indices']
3238

3339
def __str__(self):
3440
s = "SO3 descriptor with Cutoff: {:6.3f}".format(self.rcut)
@@ -69,7 +75,7 @@ def nmax(self, nmax):
6975
if nmax < 1:
7076
raise ValueError('nmax must be greater than or equal to 1')
7177
if nmax > 11:
72-
raise ValueError('nmax > 11 yields complex eigenvalues which will mess up the calculation')
78+
raise ValueError('nmax > 11 yields complex eigenvalues')
7379
self._nmax = nmax
7480
else:
7581
raise ValueError('nmax must be an integer')
@@ -84,9 +90,7 @@ def lmax(self, lmax):
8490
if lmax < 0:
8591
raise ValueError('lmax must be greater than or equal to zero')
8692
elif lmax > 32:
87-
raise NotImplementedError('''Currently we only support Wigner-D matrices and spherical harmonics
88-
for arguments up to l=32. If you need higher functionality, raise an issue
89-
in our Github and we will expand the set of supported functions''')
93+
raise NotImplementedError('support a maxmimum l=32 for spherical harmonics')
9094
self._lmax = lmax
9195
else:
9296
raise ValueError('lmax must be an integer')
@@ -117,17 +121,6 @@ def alpha(self, alpha):
117121
else:
118122
raise ValueError('alpha must be a float')
119123

120-
@property
121-
def derivative(self):
122-
return self._derivative
123-
124-
@derivative.setter
125-
def derivative(self, derivative):
126-
if isinstance(derivative, bool) is True:
127-
self._derivative = derivative
128-
else:
129-
raise ValueError('derivative must be a boolean value')
130-
131124
@property
132125
def cutoff_function(self):
133126
return self._cutoff_function
@@ -142,133 +135,130 @@ def clear_memory(self):
142135
'''
143136
attrs = list(vars(self).keys())
144137
for attr in attrs:
145-
if attr not in {'_nmax', '_lmax', '_rcut', '_alpha', '_derivative', '_cutoff_function', 'weight_on', 'neighborcalc'}:
138+
if attr not in self.keys:
146139
delattr(self, attr)
147140
return
148141

149-
def calculate(self, atoms, atom_ids=None, derivative=False):
150-
'''
151-
Calculates the SO(3) power spectrum components of the
152-
smoothened atomic neighbor density function
153-
for given nmax, lmax, rcut, and alpha.
154-
155-
Args:
156-
atoms: an ASE atoms object corresponding to the desired
157-
atomic arrangement
158-
atom_ids:
159-
derivative: bool, whether to calculate the gradient of not
160-
'''
142+
def init_atoms(self, atoms, atom_ids):
143+
"""
144+
initilize atoms related attributes
145+
"""
161146
self._atoms = atoms
147+
self.natoms = len(atoms)
162148
self.build_neighbor_list(atom_ids)
163-
self.initialize_arrays()
164149

165-
ncoefs = self.nmax*(self.nmax+1)//2*(self.lmax+1)
166-
tril_indices = np.tril_indices(self.nmax, k=0)
150+
def compute_p(self, atoms, atom_ids=None):
151+
"""
152+
Compute the powerspectrum function
167153
168-
ls = np.arange(self.lmax+1)
169-
norm = np.sqrt(2*np.sqrt(2)*np.pi/np.sqrt(2*ls+1))
154+
Args:
155+
atoms: ase atoms object
156+
atom_ids: optional list of atomic indices
170157
171-
if derivative:
172-
# get expansion coefficients and derivatives
173-
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
158+
Returns:
159+
p array (N, M)
160+
"""
174161

175-
# weight cs and dcs
162+
self.init_atoms(atoms, atom_ids)
163+
plist = np.zeros((len(atoms), self.ncoefs), dtype=np.float64)
164+
if len(self.neighborlist) > 0:
165+
cs = compute_cs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
176166
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
177-
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
178-
cs = np.einsum('inlm,l->inlm', cs, norm)
179-
dcs = np.einsum('inlmj,l->inlmj', dcs, norm)
180-
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)
181-
182-
# Assign cs and dcs to P and dP
183-
# cs: (N_ij, n, l, m) => P (N_i, N_des)
184-
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
185-
# (n, l, m) needs to be merged to 1 dimension
167+
cs = np.einsum('inlm,l->inlm', cs, self.norm)
186168

169+
# Get r_ij and compute C*np.conj(C)
187170
for i in range(len(atoms)):
188-
# find atoms for which i is the center
189-
centers = self.neighbor_indices[:, 0] == i
171+
centers = self.neighbor_indices[:,0] == i
172+
if len(centers) > 0:
173+
ctot = cs[centers].sum(axis=0)
174+
P = np.einsum('ijk,ljk->ilj', ctot, np.conj(ctot)).real
175+
plist[i] = P[self.tril_indices].flatten()
176+
return plist
177+
178+
def compute_dpdr(self, atoms, atom_ids=None):
179+
"""
180+
Compute the powerspectrum function
181+
182+
Args:
183+
atoms: ase atoms object
184+
atom_ids: optional list of atomic indices
185+
186+
Returns:
187+
dpdr array (N, N, M, 3) and p array (N, M)
188+
"""
190189

190+
self.init_atoms(atoms, atom_ids)
191+
p_list = np.zeros((self.natoms, self.ncoefs), dtype=np.float64)
192+
dp_list = np.zeros((self.natoms, self.natoms, self.ncoefs, 3), dtype=np.float64)
193+
194+
# get expansion coefficients and derivatives
195+
cs, dcs = compute_dcs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
196+
197+
# weight cs and dcs
198+
cs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis]
199+
dcs *= self.atomic_weights[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
200+
cs = np.einsum('inlm,l->inlm', cs, self.norm)
201+
dcs = np.einsum('inlmj,l->inlmj', dcs, self.norm)
202+
#print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)
203+
204+
# Assign cs and dcs to P and dP
205+
# cs: (N_ij, n, l, m) => P (N_i, N_des)
206+
# dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
207+
# (n, l, m) needs to be merged to 1 dimension
208+
209+
for i in range(len(atoms)):
210+
# find atoms for which i is the center
211+
centers = self.neighbor_indices[:, 0] == i
212+
213+
if len(centers) > 0:
191214
# total up the c array for the center atom
192215
ctot = cs[centers].sum(axis=0) #(n, l, m)
193216

194217
# power spectrum P = c*c_conj
195218
# eq_3 (n, n', l) eliminate m
196219
P = np.einsum('ijk, ljk->ilj', ctot, np.conj(ctot)).real
197-
198-
# merge (n, n', l) to 1 dimension
199-
self._plist[i] = P[tril_indices].flatten()
220+
p_list[i] = P[self.tril_indices].flatten()
200221

201222
# gradient of P for each neighbor, eq_26
202223
# (N_ijs, n, n', l, 3)
203224
# dc * c_conj + c * dc_conj
204225
dP = np.einsum('wijkn,ljk->wiljn', dcs[centers], np.conj(ctot))
205-
dP += np.conj(np.transpose(dP, axes=[0,2,1,3,4]))
226+
dP += np.conj(np.transpose(dP, axes=[0, 2, 1, 3, 4]))
206227
dP = dP.real
207228

208229
#print("shape of P/dP", P.shape, dP.shape)#; import sys; sys.exit()
209230

210-
#ijs = self.neighbor_indices[centers]
211-
#for _id, j in enumerate(ijs[:, 1]):
212-
# self._dplist[i, j, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3)
213-
# # QZ: to check
214-
# self._dplist[i, i, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3)
215-
231+
# QZ: to check
216232
ijs = self.neighbor_indices[centers]
217-
for _id, (i_idx, j_idx) in enumerate(ijs):#(ijs[:, 1]):
218-
Rij = atoms.positions[j_idx] - atoms.positions[i_idx]
219-
norm_Rij = np.linalg.norm(Rij)
220-
for m in range(len(atoms)):
221-
if m != i_idx and m != j_idx:
222-
normalization_factor = 0
223-
self._dplist[i, m, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3) * normalization_factor
224-
elif m == i_idx:
225-
normalization_factor = -1 / norm_Rij
226-
self._dplist[i, m, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3) * normalization_factor
227-
elif m == j_idx:
228-
normalization_factor = 1 / norm_Rij
229-
self._dplist[i, m, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3) * normalization_factor
230-
231-
x = {'x':self._plist,
232-
'dxdr':self._dplist,
233-
'elements':list(atoms.symbols)}
234-
else:
235-
if len(self.neighborlist) > 0:
236-
cs = compute_cs(self.neighborlist, self.nmax, self.lmax, self.rcut, self.alpha, self._cutoff_function)
237-
cs *= self.atomic_weights[:,np.newaxis,np.newaxis,np.newaxis]
238-
cs = np.einsum('inlm,l->inlm', cs, norm)
239-
# everything good up to here
240-
for i in range(len(atoms)):
241-
centers = self.neighbor_indices[:,0] == i
242-
ctot = cs[centers].sum(axis=0)
243-
P = np.einsum('ijk,ljk->ilj', ctot, np.conj(ctot)).real
244-
self._plist[i] = P[tril_indices].flatten()
245-
x = {'x': self._plist,
246-
'dxdr': None,
247-
'elements': list(atoms.symbols)}
248-
249-
self.clear_memory()
250-
return x
251-
252-
def initialize_arrays(self):
253-
# number of atoms
254-
natoms = len(self._atoms) #self._atoms)
233+
for _id, j in enumerate(ijs[:, 1]):
234+
dp_list[i, j, :, :] += dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
235+
dp_list[i, i, :, :] -= dP[_id][self.tril_indices].flatten().reshape(self.ncoefs, 3)
255236

256-
# degree of spherical harmonic expansion
257-
lmax = self.lmax
237+
return dp_list, p_list
258238

259-
# degree of radial expansion
260-
nmax = self.nmax
239+
def calculate(self, atoms, derivative=False):
240+
'''
241+
API for Calculating the SO(3) power spectrum components of the
242+
smoothened atomic neighbor density function
261243
262-
# number of unique power spectrum components
263-
# this is given by the triangular elements of
264-
# the radial expansion multiplied by the degree
265-
# of spherical harmonic expansion (including 0)
266-
ncoefs = nmax*(nmax+1)//2*(lmax+1)
244+
Args:
245+
atoms: an ASE atoms object corresponding to the desired
246+
atomic arrangement
247+
derivative: bool, whether to calculate the gradient of not
248+
'''
249+
p_list = None
250+
dp_list = None
251+
if derivative:
252+
dp_list, p_list = self.compute_dpdr(atoms)
253+
else:
254+
p_list = self.compute_p(atoms)
267255

268-
self._plist = np.zeros((natoms, ncoefs), dtype=np.float64)
269-
self._dplist = np.zeros((natoms, natoms, ncoefs, 3), dtype=np.float64)
256+
x = {'x': p_list,
257+
'dxdr': dp_list,
258+
'elements': list(atoms.symbols)}
259+
self.clear_memory()
260+
return x
270261

271-
return
272262

273263
def build_neighbor_list(self, atom_ids=None):
274264
'''
@@ -323,11 +313,10 @@ def build_neighbor_list(self, atom_ids=None):
323313
self.neighborlist = np.array(neighbors, dtype=np.float64)
324314
self.atomic_weights = np.array(atomic_weights, dtype=np.int64)
325315
self.neighbor_indices = neighbor_indices
326-
return
327316

328317
def Cosine(Rij, Rc, derivative=False):
329318
# Rij is the norm
330-
if derivative is False:
319+
if not derivative:
331320
result = 0.5 * (np.cos(np.pi * Rij / Rc) + 1.)
332321
else:
333322
result = -0.5 * np.pi / Rc * np.sin(np.pi * Rij / Rc)
@@ -640,9 +629,11 @@ def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):
640629
der = options.der
641630

642631
start1 = time.time()
643-
f = SO3(nmax=nmax, lmax=lmax, rcut=rcut, alpha=alpha, cutoff_function='cosine')
644-
x = f.calculate(test, derivative=True)
632+
f = SO3(nmax=nmax, lmax=lmax, rcut=rcut, alpha=alpha)
633+
x = f.calculate(test, derivative=der)
645634
start2 = time.time()
635+
print(f)
646636
print('x', x['x'])
647-
print('dxdr', x['dxdr'])
637+
#print('dxdr', x['dxdr'])
648638
print('calculation time {}'.format(start2-start1))
639+
print(f.compute_p(test))

0 commit comments

Comments
 (0)