@@ -28,7 +28,13 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0,
2828 self .cutoff_function = 'cosine'
2929 self .weight_on = weight_on
3030 self .neighborcalc = neighborlist
31- #return
31+ self .ncoefs = self .nmax * (self .nmax + 1 )// 2 * (self .lmax + 1 )
32+ self .tril_indices = np .tril_indices (self .nmax , k = 0 )
33+ self .ls = np .arange (self .lmax + 1 )
34+ self .norm = np .sqrt (2 * np .sqrt (2 )* np .pi / np .sqrt (2 * self .ls + 1 ))
35+ self .keys = ['keys' , '_nmax' , '_lmax' , '_rcut' , '_alpha' ,
36+ '_cutoff_function' , 'weight_on' , 'neighborcalc' ,
37+ 'ncoefs' , 'ls' , 'norm' , 'tril_indices' ]
3238
3339 def __str__ (self ):
3440 s = "SO3 descriptor with Cutoff: {:6.3f}" .format (self .rcut )
@@ -69,7 +75,7 @@ def nmax(self, nmax):
6975 if nmax < 1 :
7076 raise ValueError ('nmax must be greater than or equal to 1' )
7177 if nmax > 11 :
72- raise ValueError ('nmax > 11 yields complex eigenvalues which will mess up the calculation ' )
78+ raise ValueError ('nmax > 11 yields complex eigenvalues' )
7379 self ._nmax = nmax
7480 else :
7581 raise ValueError ('nmax must be an integer' )
@@ -84,9 +90,7 @@ def lmax(self, lmax):
8490 if lmax < 0 :
8591 raise ValueError ('lmax must be greater than or equal to zero' )
8692 elif lmax > 32 :
87- raise NotImplementedError ('''Currently we only support Wigner-D matrices and spherical harmonics
88- for arguments up to l=32. If you need higher functionality, raise an issue
89- in our Github and we will expand the set of supported functions''' )
93+ raise NotImplementedError ('support a maxmimum l=32 for spherical harmonics' )
9094 self ._lmax = lmax
9195 else :
9296 raise ValueError ('lmax must be an integer' )
@@ -117,17 +121,6 @@ def alpha(self, alpha):
117121 else :
118122 raise ValueError ('alpha must be a float' )
119123
120- @property
121- def derivative (self ):
122- return self ._derivative
123-
124- @derivative .setter
125- def derivative (self , derivative ):
126- if isinstance (derivative , bool ) is True :
127- self ._derivative = derivative
128- else :
129- raise ValueError ('derivative must be a boolean value' )
130-
131124 @property
132125 def cutoff_function (self ):
133126 return self ._cutoff_function
@@ -142,133 +135,130 @@ def clear_memory(self):
142135 '''
143136 attrs = list (vars (self ).keys ())
144137 for attr in attrs :
145- if attr not in { '_nmax' , '_lmax' , '_rcut' , '_alpha' , '_derivative' , '_cutoff_function' , 'weight_on' , 'neighborcalc' } :
138+ if attr not in self . keys :
146139 delattr (self , attr )
147140 return
148141
149- def calculate (self , atoms , atom_ids = None , derivative = False ):
150- '''
151- Calculates the SO(3) power spectrum components of the
152- smoothened atomic neighbor density function
153- for given nmax, lmax, rcut, and alpha.
154-
155- Args:
156- atoms: an ASE atoms object corresponding to the desired
157- atomic arrangement
158- atom_ids:
159- derivative: bool, whether to calculate the gradient of not
160- '''
142+ def init_atoms (self , atoms , atom_ids ):
143+ """
144+ initilize atoms related attributes
145+ """
161146 self ._atoms = atoms
147+ self .natoms = len (atoms )
162148 self .build_neighbor_list (atom_ids )
163- self .initialize_arrays ()
164149
165- ncoefs = self .nmax * (self .nmax + 1 )// 2 * (self .lmax + 1 )
166- tril_indices = np .tril_indices (self .nmax , k = 0 )
150+ def compute_p (self , atoms , atom_ids = None ):
151+ """
152+ Compute the powerspectrum function
167153
168- ls = np .arange (self .lmax + 1 )
169- norm = np .sqrt (2 * np .sqrt (2 )* np .pi / np .sqrt (2 * ls + 1 ))
154+ Args:
155+ atoms: ase atoms object
156+ atom_ids: optional list of atomic indices
170157
171- if derivative :
172- # get expansion coefficients and derivatives
173- cs , dcs = compute_dcs ( self . neighborlist , self . nmax , self . lmax , self . rcut , self . alpha , self . _cutoff_function )
158+ Returns :
159+ p array (N, M)
160+ """
174161
175- # weight cs and dcs
162+ self .init_atoms (atoms , atom_ids )
163+ plist = np .zeros ((len (atoms ), self .ncoefs ), dtype = np .float64 )
164+ if len (self .neighborlist ) > 0 :
165+ cs = compute_cs (self .neighborlist , self .nmax , self .lmax , self .rcut , self .alpha , self ._cutoff_function )
176166 cs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis ]
177- dcs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis , np .newaxis ]
178- cs = np .einsum ('inlm,l->inlm' , cs , norm )
179- dcs = np .einsum ('inlmj,l->inlmj' , dcs , norm )
180- #print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)
181-
182- # Assign cs and dcs to P and dP
183- # cs: (N_ij, n, l, m) => P (N_i, N_des)
184- # dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
185- # (n, l, m) needs to be merged to 1 dimension
167+ cs = np .einsum ('inlm,l->inlm' , cs , self .norm )
186168
169+ # Get r_ij and compute C*np.conj(C)
187170 for i in range (len (atoms )):
188- # find atoms for which i is the center
189- centers = self .neighbor_indices [:, 0 ] == i
171+ centers = self .neighbor_indices [:,0 ] == i
172+ if len (centers ) > 0 :
173+ ctot = cs [centers ].sum (axis = 0 )
174+ P = np .einsum ('ijk,ljk->ilj' , ctot , np .conj (ctot )).real
175+ plist [i ] = P [self .tril_indices ].flatten ()
176+ return plist
177+
178+ def compute_dpdr (self , atoms , atom_ids = None ):
179+ """
180+ Compute the powerspectrum function
181+
182+ Args:
183+ atoms: ase atoms object
184+ atom_ids: optional list of atomic indices
185+
186+ Returns:
187+ dpdr array (N, N, M, 3) and p array (N, M)
188+ """
190189
190+ self .init_atoms (atoms , atom_ids )
191+ p_list = np .zeros ((self .natoms , self .ncoefs ), dtype = np .float64 )
192+ dp_list = np .zeros ((self .natoms , self .natoms , self .ncoefs , 3 ), dtype = np .float64 )
193+
194+ # get expansion coefficients and derivatives
195+ cs , dcs = compute_dcs (self .neighborlist , self .nmax , self .lmax , self .rcut , self .alpha , self ._cutoff_function )
196+
197+ # weight cs and dcs
198+ cs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis ]
199+ dcs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis , np .newaxis ]
200+ cs = np .einsum ('inlm,l->inlm' , cs , self .norm )
201+ dcs = np .einsum ('inlmj,l->inlmj' , dcs , self .norm )
202+ #print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)
203+
204+ # Assign cs and dcs to P and dP
205+ # cs: (N_ij, n, l, m) => P (N_i, N_des)
206+ # dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
207+ # (n, l, m) needs to be merged to 1 dimension
208+
209+ for i in range (len (atoms )):
210+ # find atoms for which i is the center
211+ centers = self .neighbor_indices [:, 0 ] == i
212+
213+ if len (centers ) > 0 :
191214 # total up the c array for the center atom
192215 ctot = cs [centers ].sum (axis = 0 ) #(n, l, m)
193216
194217 # power spectrum P = c*c_conj
195218 # eq_3 (n, n', l) eliminate m
196219 P = np .einsum ('ijk, ljk->ilj' , ctot , np .conj (ctot )).real
197-
198- # merge (n, n', l) to 1 dimension
199- self ._plist [i ] = P [tril_indices ].flatten ()
220+ p_list [i ] = P [self .tril_indices ].flatten ()
200221
201222 # gradient of P for each neighbor, eq_26
202223 # (N_ijs, n, n', l, 3)
203224 # dc * c_conj + c * dc_conj
204225 dP = np .einsum ('wijkn,ljk->wiljn' , dcs [centers ], np .conj (ctot ))
205- dP += np .conj (np .transpose (dP , axes = [0 ,2 , 1 , 3 , 4 ]))
226+ dP += np .conj (np .transpose (dP , axes = [0 , 2 , 1 , 3 , 4 ]))
206227 dP = dP .real
207228
208229 #print("shape of P/dP", P.shape, dP.shape)#; import sys; sys.exit()
209230
210- #ijs = self.neighbor_indices[centers]
211- #for _id, j in enumerate(ijs[:, 1]):
212- # self._dplist[i, j, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3)
213- # # QZ: to check
214- # self._dplist[i, i, :, :] += dP[_id][tril_indices].flatten().reshape(ncoefs, 3)
215-
231+ # QZ: to check
216232 ijs = self .neighbor_indices [centers ]
217- for _id , (i_idx , j_idx ) in enumerate (ijs ):#(ijs[:, 1]):
218- Rij = atoms .positions [j_idx ] - atoms .positions [i_idx ]
219- norm_Rij = np .linalg .norm (Rij )
220- for m in range (len (atoms )):
221- if m != i_idx and m != j_idx :
222- normalization_factor = 0
223- self ._dplist [i , m , :, :] += dP [_id ][tril_indices ].flatten ().reshape (ncoefs , 3 ) * normalization_factor
224- elif m == i_idx :
225- normalization_factor = - 1 / norm_Rij
226- self ._dplist [i , m , :, :] += dP [_id ][tril_indices ].flatten ().reshape (ncoefs , 3 ) * normalization_factor
227- elif m == j_idx :
228- normalization_factor = 1 / norm_Rij
229- self ._dplist [i , m , :, :] += dP [_id ][tril_indices ].flatten ().reshape (ncoefs , 3 ) * normalization_factor
230-
231- x = {'x' :self ._plist ,
232- 'dxdr' :self ._dplist ,
233- 'elements' :list (atoms .symbols )}
234- else :
235- if len (self .neighborlist ) > 0 :
236- cs = compute_cs (self .neighborlist , self .nmax , self .lmax , self .rcut , self .alpha , self ._cutoff_function )
237- cs *= self .atomic_weights [:,np .newaxis ,np .newaxis ,np .newaxis ]
238- cs = np .einsum ('inlm,l->inlm' , cs , norm )
239- # everything good up to here
240- for i in range (len (atoms )):
241- centers = self .neighbor_indices [:,0 ] == i
242- ctot = cs [centers ].sum (axis = 0 )
243- P = np .einsum ('ijk,ljk->ilj' , ctot , np .conj (ctot )).real
244- self ._plist [i ] = P [tril_indices ].flatten ()
245- x = {'x' : self ._plist ,
246- 'dxdr' : None ,
247- 'elements' : list (atoms .symbols )}
248-
249- self .clear_memory ()
250- return x
251-
252- def initialize_arrays (self ):
253- # number of atoms
254- natoms = len (self ._atoms ) #self._atoms)
233+ for _id , j in enumerate (ijs [:, 1 ]):
234+ dp_list [i , j , :, :] += dP [_id ][self .tril_indices ].flatten ().reshape (self .ncoefs , 3 )
235+ dp_list [i , i , :, :] -= dP [_id ][self .tril_indices ].flatten ().reshape (self .ncoefs , 3 )
255236
256- # degree of spherical harmonic expansion
257- lmax = self .lmax
237+ return dp_list , p_list
258238
259- # degree of radial expansion
260- nmax = self .nmax
239+ def calculate (self , atoms , derivative = False ):
240+ '''
241+ API for Calculating the SO(3) power spectrum components of the
242+ smoothened atomic neighbor density function
261243
262- # number of unique power spectrum components
263- # this is given by the triangular elements of
264- # the radial expansion multiplied by the degree
265- # of spherical harmonic expansion (including 0)
266- ncoefs = nmax * (nmax + 1 )// 2 * (lmax + 1 )
244+ Args:
245+ atoms: an ASE atoms object corresponding to the desired
246+ atomic arrangement
247+ derivative: bool, whether to calculate the gradient of not
248+ '''
249+ p_list = None
250+ dp_list = None
251+ if derivative :
252+ dp_list , p_list = self .compute_dpdr (atoms )
253+ else :
254+ p_list = self .compute_p (atoms )
267255
268- self ._plist = np .zeros ((natoms , ncoefs ), dtype = np .float64 )
269- self ._dplist = np .zeros ((natoms , natoms , ncoefs , 3 ), dtype = np .float64 )
256+ x = {'x' : p_list ,
257+ 'dxdr' : dp_list ,
258+ 'elements' : list (atoms .symbols )}
259+ self .clear_memory ()
260+ return x
270261
271- return
272262
273263 def build_neighbor_list (self , atom_ids = None ):
274264 '''
@@ -323,11 +313,10 @@ def build_neighbor_list(self, atom_ids=None):
323313 self .neighborlist = np .array (neighbors , dtype = np .float64 )
324314 self .atomic_weights = np .array (atomic_weights , dtype = np .int64 )
325315 self .neighbor_indices = neighbor_indices
326- return
327316
328317def Cosine (Rij , Rc , derivative = False ):
329318 # Rij is the norm
330- if derivative is False :
319+ if not derivative :
331320 result = 0.5 * (np .cos (np .pi * Rij / Rc ) + 1. )
332321 else :
333322 result = - 0.5 * np .pi / Rc * np .sin (np .pi * Rij / Rc )
@@ -640,9 +629,11 @@ def compute_dcs(pos, nmax, lmax, rcut, alpha, cutoff):
640629 der = options .der
641630
642631 start1 = time .time ()
643- f = SO3 (nmax = nmax , lmax = lmax , rcut = rcut , alpha = alpha , cutoff_function = 'cosine' )
644- x = f .calculate (test , derivative = True )
632+ f = SO3 (nmax = nmax , lmax = lmax , rcut = rcut , alpha = alpha )
633+ x = f .calculate (test , derivative = der )
645634 start2 = time .time ()
635+ print (f )
646636 print ('x' , x ['x' ])
647- print ('dxdr' , x ['dxdr' ])
637+ # print('dxdr', x['dxdr'])
648638 print ('calculation time {}' .format (start2 - start1 ))
639+ print (f .compute_p (test ))
0 commit comments