66 Union ,
77)
88
9+ import array_api_compat
910import numpy as np
1011
1112from deepmd .dpmodel import (
1213 DEFAULT_PRECISION ,
1314 PRECISION_DICT ,
1415 NativeOP ,
1516)
17+ from deepmd .dpmodel .common import (
18+ get_xp_precision ,
19+ to_numpy_array ,
20+ )
1621from deepmd .dpmodel .utils import (
1722 EmbeddingNet ,
1823 EnvMat ,
2530from deepmd .dpmodel .utils .update_sel import (
2631 UpdateSel ,
2732)
28- from deepmd .env import (
29- GLOBAL_NP_FLOAT_PRECISION ,
30- )
3133from deepmd .utils .data_system import (
3234 DeepmdDataSystem ,
3335)
@@ -122,26 +124,28 @@ def __init__(
122124 # order matters, placed after the assignment of self.ntypes
123125 self .reinit_exclude (exclude_types )
124126 self .trainable = trainable
127+ self .sel_cumsum = [0 , * np .cumsum (self .sel ).tolist ()]
125128
126129 in_dim = 1 # not considiering type embedding
127- self . embeddings = NetworkCollection (
130+ embeddings = NetworkCollection (
128131 ntypes = self .ntypes ,
129132 ndim = 2 ,
130133 network_type = "embedding_network" ,
131134 )
132135 for ii , embedding_idx in enumerate (
133- itertools .product (range (self .ntypes ), repeat = self . embeddings .ndim )
136+ itertools .product (range (self .ntypes ), repeat = embeddings .ndim )
134137 ):
135- self . embeddings [embedding_idx ] = EmbeddingNet (
138+ embeddings [embedding_idx ] = EmbeddingNet (
136139 in_dim ,
137140 self .neuron ,
138141 self .activation_function ,
139142 self .resnet_dt ,
140143 self .precision ,
141144 seed = child_seed (self .seed , ii ),
142145 )
146+ self .embeddings = embeddings
143147 self .env_mat = EnvMat (self .rcut , self .rcut_smth , protection = self .env_protection )
144- self .nnei = np . sum (self .sel )
148+ self .nnei = sum (self .sel )
145149 self .davg = np .zeros (
146150 [self .ntypes , self .nnei , 4 ], dtype = PRECISION_DICT [self .precision ]
147151 )
@@ -299,20 +303,22 @@ def call(
299303 The smooth switch function.
300304 """
301305 del mapping
306+ xp = array_api_compat .array_namespace (coord_ext , atype_ext , nlist )
302307 # nf x nloc x nnei x 4
303308 rr , diff , ww = self .env_mat .call (
304309 coord_ext , atype_ext , nlist , self .davg , self .dstd
305310 )
306311 nf , nloc , nnei , _ = rr .shape
307- sec = np . append ([ 0 ], np . cumsum ( self .sel ))
312+ sec = self .sel_cumsum
308313
309314 ng = self .neuron [- 1 ]
310- result = np .zeros ([nf * nloc , ng ], dtype = PRECISION_DICT [ self .precision ] )
315+ result = xp .zeros ([nf * nloc , ng ], dtype = get_xp_precision ( xp , self .precision ) )
311316 exclude_mask = self .emask .build_type_exclude_mask (nlist , atype_ext )
312317 # merge nf and nloc axis, so for type_one_side == False,
313318 # we don't require atype is the same in all frames
314- exclude_mask = exclude_mask .reshape (nf * nloc , nnei )
315- rr = rr .reshape (nf * nloc , nnei , 4 )
319+ exclude_mask = xp .reshape (exclude_mask , (nf * nloc , nnei ))
320+ rr = xp .reshape (rr , (nf * nloc , nnei , 4 ))
321+ rr = xp .astype (rr , get_xp_precision (xp , self .precision ))
316322
317323 for embedding_idx in itertools .product (
318324 range (self .ntypes ), repeat = self .embeddings .ndim
@@ -325,23 +331,26 @@ def call(
325331 # nfnl x nt_i x 3
326332 rr_i = rr [:, sec [ti ] : sec [ti + 1 ], 1 :]
327333 mm_i = exclude_mask [:, sec [ti ] : sec [ti + 1 ]]
328- rr_i = rr_i * mm_i [:, :, None ]
334+ rr_i = rr_i * xp . astype ( mm_i [:, :, None ], rr_i . dtype )
329335 # nfnl x nt_j x 3
330336 rr_j = rr [:, sec [tj ] : sec [tj + 1 ], 1 :]
331337 mm_j = exclude_mask [:, sec [tj ] : sec [tj + 1 ]]
332- rr_j = rr_j * mm_j [:, :, None ]
338+ rr_j = rr_j * xp . astype ( mm_j [:, :, None ], rr_j . dtype )
333339 # nfnl x nt_i x nt_j
334- env_ij = np .einsum ("ijm,ikm->ijk" , rr_i , rr_j )
340+ # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
341+ env_ij = xp .sum (rr_i [:, :, None , :] * rr_j [:, None , :, :], axis = - 1 )
335342 # nfnl x nt_i x nt_j x 1
336343 env_ij_reshape = env_ij [:, :, :, None ]
337344 # nfnl x nt_i x nt_j x ng
338345 gg = self .embeddings [embedding_idx ].call (env_ij_reshape )
339346 # nfnl x nt_i x nt_j x ng
340- res_ij = np .einsum ("ijk,ijkm->im" , env_ij , gg )
347+ # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
348+ res_ij = xp .sum (env_ij [:, :, :, None ] * gg , axis = (1 , 2 ))
341349 res_ij = res_ij * (1.0 / float (nei_type_i ) / float (nei_type_j ))
342350 result += res_ij
343351 # nf x nloc x ng
344- result = result .reshape (nf , nloc , ng ).astype (GLOBAL_NP_FLOAT_PRECISION )
352+ result = xp .reshape (result , (nf , nloc , ng ))
353+ result = xp .astype (result , get_xp_precision (xp , "global" ))
345354 return result , None , None , None , ww
346355
347356 def serialize (self ) -> dict :
@@ -369,8 +378,8 @@ def serialize(self) -> dict:
369378 "exclude_types" : self .exclude_types ,
370379 "env_protection" : self .env_protection ,
371380 "@variables" : {
372- "davg" : self .davg ,
373- "dstd" : self .dstd ,
381+ "davg" : to_numpy_array ( self .davg ) ,
382+ "dstd" : to_numpy_array ( self .dstd ) ,
374383 },
375384 "type_map" : self .type_map ,
376385 "trainable" : self .trainable ,
0 commit comments