@@ -139,7 +139,7 @@ def clear_memory(self):
139139 delattr (self , attr )
140140 return
141141
142- def init_atoms (self , atoms , atom_ids ):
142+ def init_atoms (self , atoms , atom_ids = None ):
143143 """
144144 initilize atoms related attributes
145145 """
@@ -210,7 +210,7 @@ def compute_dpdr(self, atoms, atom_ids=None):
210210 # find atoms for which i is the center
211211 centers = self .neighbor_indices [:, 0 ] == i
212212
213- if len (centers ) > 0 :
213+ if len (self . neighbor_indices [ centers ] ) > 0 :
214214 # total up the c array for the center atom
215215 ctot = cs [centers ].sum (axis = 0 ) #(n, l, m)
216216
@@ -236,7 +236,68 @@ def compute_dpdr(self, atoms, atom_ids=None):
236236
237237 return dp_list , p_list
238238
239- def calculate (self , atoms , derivative = False ):
239+ def compute_dpdr_5d (self , atoms ):
240+ """
241+ Compute the powerspectrum function with respect to supercell
242+
243+ Args:
244+ atoms: ase atoms object
245+ atom_ids: optional list of atomic indices
246+
247+ Returns:
248+ dpdr array (N, N, M, 3, 27) and p array (N, M)
249+ """
250+
251+ self .init_atoms (atoms )
252+ p_list = np .zeros ((self .natoms , self .ncoefs ), dtype = np .float64 )
253+ dp_list = np .zeros ((self .natoms , self .natoms , self .ncoefs , 3 , 27 ), dtype = np .float64 )
254+
255+ # get expansion coefficients and derivatives
256+ cs , dcs = compute_dcs (self .neighborlist , self .nmax , self .lmax , self .rcut , self .alpha , self ._cutoff_function )
257+
258+ # weight cs and dcs
259+ cs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis ]
260+ dcs *= self .atomic_weights [:, np .newaxis , np .newaxis , np .newaxis , np .newaxis ]
261+ cs = np .einsum ('inlm,l->inlm' , cs , self .norm )
262+ dcs = np .einsum ('inlmj,l->inlmj' , dcs , self .norm )
263+ #print('cs, dcs', self.neighbor_indices, cs.shape, dcs.shape)
264+
265+ # Assign cs and dcs to P and dP
266+ # cs: (N_ij, n, l, m) => P (N_i, N_des)
267+ # dcs: (N_ij, n, l, m, 3) => dP (N_i, N_j, N_des, 3)
268+ # (n, l, m) needs to be merged to 1 dimension
269+ neigh_ids = np .arange (len (self .neighbor_indices ))
270+ for i in range (len (atoms )):
271+ # find atoms for which i is the center
272+ pair_ids = neigh_ids [self .neighbor_indices [:, 0 ] == i ]
273+
274+ # loop over each pair
275+ for pair_id in pair_ids :
276+ (_ , j , x , y , z ) = self .neighbor_indices [pair_id ]
277+ # map from (x, y, z) to (0, 27)
278+ cell_id = (x + 1 ) * 9 + (y + 1 ) * 3 + z + 1
279+
280+ # power spectrum P = c*c_conj
281+ # eq_3 (n, n', l) eliminate m
282+ P = np .einsum ('ijk, ljk->ilj' , cs [pair_id ], np .conj (cs [pair_id ])).real
283+ p_list [i ] = P [self .tril_indices ].flatten ()
284+
285+ # gradient of P for each neighbor, eq_26
286+ # (N_ijs, n, n', l, 3)
287+ # dc * c_conj + c * dc_conj
288+ dP = np .einsum ('ijkn, ljk->iljn' , dcs [pair_id ], np .conj (cs [pair_id ]))
289+ dP += np .conj (np .transpose (dP , axes = [1 , 0 , 2 , 3 ]))
290+ dP = dP .real [self .tril_indices ].flatten ().reshape (self .ncoefs , 3 )
291+ #print(cs[pair_id].shape, dcs[pair_id].shape, dP.shape)
292+
293+ dp_list [i , j , :, :, cell_id ] += dP
294+ dp_list [i , i , :, :, cell_id ] -= dP
295+
296+ return dp_list , p_list
297+
298+
299+
300+ def calculate (self , atoms , atom_ids = None , derivative = False ):
240301 '''
241302 API for Calculating the SO(3) power spectrum components of the
242303 smoothened atomic neighbor density function
@@ -249,9 +310,9 @@ def calculate(self, atoms, derivative=False):
249310 p_list = None
250311 dp_list = None
251312 if derivative :
252- dp_list , p_list = self .compute_dpdr (atoms )
313+ dp_list , p_list = self .compute_dpdr (atoms , atom_ids )
253314 else :
254- p_list = self .compute_p (atoms )
315+ p_list = self .compute_p (atoms , atom_ids )
255316
256317 x = {'x' : p_list ,
257318 'dxdr' : dp_list ,
@@ -295,7 +356,7 @@ def build_neighbor_list(self, atom_ids=None):
295356 #print(indices); import sys; sys.exit()
296357 temp_indices .append (indices )
297358 for j , offset in zip (indices , offsets ):
298- pos = atoms .positions [j ] + np .dot (offset ,atoms .get_cell ()) - center_atom
359+ pos = atoms .positions [j ] + np .dot (offset , atoms .get_cell ()) - center_atom
299360 # to prevent division by zero
300361 if np .sum (np .abs (pos )) < 1e-3 : pos += 0.001
301362 center_atoms .append (center_atom )
@@ -305,7 +366,7 @@ def build_neighbor_list(self, atom_ids=None):
305366 else :
306367 factor = 1
307368 atomic_weights .append (factor * atoms [j ].number )
308- neighbor_indices .append ([i ,j ])
369+ neighbor_indices .append ([i , j , * offset ])
309370
310371 neighbor_indices = np .array (neighbor_indices , dtype = np .int64 )
311372
0 commit comments