55 Union ,
66)
77
8+ import array_api_compat
89import numpy as np
910
1011from deepmd .dpmodel import (
1112 PRECISION_DICT ,
1213 NativeOP ,
1314)
15+ from deepmd .dpmodel .array_api import (
16+ xp_take_along_axis ,
17+ )
18+ from deepmd .dpmodel .common import (
19+ get_xp_precision ,
20+ to_numpy_array ,
21+ )
1422from deepmd .dpmodel .utils import (
1523 EmbeddingNet ,
1624 EnvMat ,
2634from deepmd .dpmodel .utils .update_sel import (
2735 UpdateSel ,
2836)
29- from deepmd .env import (
30- GLOBAL_NP_FLOAT_PRECISION ,
31- )
3237from deepmd .utils .data_system import (
3338 DeepmdDataSystem ,
3439)
@@ -318,11 +323,15 @@ def call(
318323 sw
319324 The smooth switch function.
320325 """
326+ xp = array_api_compat .array_namespace (nlist , coord_ext , atype_ext )
321327 del mapping
322328 nf , nloc , nnei = nlist .shape
323- nall = coord_ext .reshape (nf , - 1 ).shape [1 ] // 3
329+ nall = xp .reshape (coord_ext , ( nf , - 1 ) ).shape [1 ] // 3
324330 # nf x nall x tebd_dim
325- atype_embd_ext = self .type_embedding .call ()[atype_ext ]
331+ atype_embd_ext = xp .reshape (
332+ xp .take (self .type_embedding .call (), xp .reshape (atype_ext , [- 1 ]), axis = 0 ),
333+ (nf , nall , self .tebd_dim ),
334+ )
326335 # nfnl x tebd_dim
327336 atype_embd = atype_embd_ext [:, :nloc , :]
328337 grrg , g2 , h2 , rot_mat , sw = self .se_ttebd (
@@ -334,8 +343,8 @@ def call(
334343 )
335344 # nf x nloc x (ng + tebd_dim)
336345 if self .concat_output_tebd :
337- grrg = np . concatenate (
338- [grrg , atype_embd .reshape (nf , nloc , self .tebd_dim )], axis = - 1
346+ grrg = xp . concat (
347+ [grrg , xp .reshape (atype_embd , ( nf , nloc , self .tebd_dim ) )], axis = - 1
339348 )
340349 return grrg , rot_mat , None , None , sw
341350
@@ -368,8 +377,8 @@ def serialize(self) -> dict:
368377 "env_protection" : obj .env_protection ,
369378 "smooth" : self .smooth ,
370379 "@variables" : {
371- "davg" : obj ["davg" ],
372- "dstd" : obj ["dstd" ],
380+ "davg" : to_numpy_array ( obj ["davg" ]) ,
381+ "dstd" : to_numpy_array ( obj ["dstd" ]) ,
373382 },
374383 "trainable" : self .trainable ,
375384 }
@@ -491,33 +500,35 @@ def __init__(
491500 else :
492501 self .embd_input_dim = 1
493502
494- self . embeddings = NetworkCollection (
503+ embeddings = NetworkCollection (
495504 ndim = 0 ,
496505 ntypes = self .ntypes ,
497506 network_type = "embedding_network" ,
498507 )
499- self . embeddings [0 ] = EmbeddingNet (
508+ embeddings [0 ] = EmbeddingNet (
500509 self .embd_input_dim ,
501510 self .neuron ,
502511 self .activation_function ,
503512 self .resnet_dt ,
504513 self .precision ,
505514 seed = child_seed (seed , 0 ),
506515 )
516+ self .embeddings = embeddings
507517 if self .tebd_input_mode in ["strip" ]:
508- self . embeddings_strip = NetworkCollection (
518+ embeddings_strip = NetworkCollection (
509519 ndim = 0 ,
510520 ntypes = self .ntypes ,
511521 network_type = "embedding_network" ,
512522 )
513- self . embeddings_strip [0 ] = EmbeddingNet (
523+ embeddings_strip [0 ] = EmbeddingNet (
514524 self .tebd_dim_input ,
515525 self .neuron ,
516526 self .activation_function ,
517527 self .resnet_dt ,
518528 self .precision ,
519529 seed = child_seed (seed , 1 ),
520530 )
531+ self .embeddings_strip = embeddings_strip
521532 else :
522533 self .embeddings_strip = None
523534
@@ -652,82 +663,85 @@ def call(
652663 atype_embd_ext : Optional [np .ndarray ] = None ,
653664 mapping : Optional [np .ndarray ] = None ,
654665 ):
666+ xp = array_api_compat .array_namespace (nlist , coord_ext , atype_ext )
655667 # nf x nloc x nnei x 4
656668 dmatrix , diff , sw = self .env_mat .call (
657669 coord_ext , atype_ext , nlist , self .mean , self .stddev
658670 )
659671 nf , nloc , nnei , _ = dmatrix .shape
660672 exclude_mask = self .emask .build_type_exclude_mask (nlist , atype_ext )
661673 # nfnl x nnei
662- exclude_mask = exclude_mask .reshape (nf * nloc , nnei )
674+ exclude_mask = xp .reshape (exclude_mask , ( nf * nloc , nnei ) )
663675 # nfnl x nnei
664- nlist = nlist .reshape (nf * nloc , nnei )
665- nlist = np .where (exclude_mask , nlist , - 1 )
676+ nlist = xp .reshape (nlist , ( nf * nloc , nnei ) )
677+ nlist = xp .where (exclude_mask , nlist , xp . full_like ( nlist , - 1 ) )
666678 # nfnl x nnei
667679 nlist_mask = nlist != - 1
668680 # nfnl x nnei x 1
669- sw = np .where (nlist_mask [:, :, None ], sw .reshape (nf * nloc , nnei , 1 ), 0.0 )
681+ sw = xp .where (
682+ nlist_mask [:, :, None ],
683+ xp .reshape (sw , (nf * nloc , nnei , 1 )),
684+ xp .zeros ((nf * nloc , nnei , 1 ), dtype = sw .dtype ),
685+ )
670686
671687 # nfnl x nnei x 4
672- dmatrix = dmatrix .reshape (nf * nloc , nnei , 4 )
688+ dmatrix = xp .reshape (dmatrix , ( nf * nloc , nnei , 4 ) )
673689 # nfnl x nnei x 4
674690 rr = dmatrix
675- rr = rr * exclude_mask [:, :, None ]
691+ rr = rr * xp . astype ( exclude_mask [:, :, None ], rr . dtype )
676692 # nfnl x nt_i x 3
677693 rr_i = rr [:, :, 1 :]
678694 # nfnl x nt_j x 3
679695 rr_j = rr [:, :, 1 :]
680696 # nfnl x nt_i x nt_j
681- env_ij = np .einsum ("ijm,ikm->ijk" , rr_i , rr_j )
697+ # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
698+ env_ij = xp .sum (rr_i [:, :, None , :] * rr_j [:, None , :, :], axis = - 1 )
682699 # nfnl x nt_i x nt_j x 1
683- ss = np . expand_dims ( env_ij , axis = - 1 )
700+ ss = env_ij [..., None ]
684701
685- nlist_masked = np .where (nlist_mask , nlist , 0 )
686- index = np .tile (nlist_masked .reshape (nf , - 1 , 1 ), (1 , 1 , self .tebd_dim ))
702+ nlist_masked = xp .where (nlist_mask , nlist , xp . zeros_like ( nlist ) )
703+ index = xp .tile (xp .reshape (nlist_masked , ( nf , - 1 , 1 ) ), (1 , 1 , self .tebd_dim ))
687704 # nfnl x nnei x tebd_dim
688- atype_embd_nlist = np .take_along_axis (atype_embd_ext , index , axis = 1 ).reshape (
689- nf * nloc , nnei , self .tebd_dim
705+ atype_embd_nlist = xp_take_along_axis (atype_embd_ext , index , axis = 1 )
706+ atype_embd_nlist = xp .reshape (
707+ atype_embd_nlist , (nf * nloc , nnei , self .tebd_dim )
690708 )
691709 # nfnl x nt_i x nt_j x tebd_dim
692- nlist_tebd_i = np .tile (
693- np .expand_dims (atype_embd_nlist , axis = 2 ), [1 , 1 , self .nnei , 1 ]
694- )
695- nlist_tebd_j = np .tile (
696- np .expand_dims (atype_embd_nlist , axis = 1 ), [1 , self .nnei , 1 , 1 ]
697- )
710+ nlist_tebd_i = xp .tile (atype_embd_nlist [:, :, None , :], (1 , 1 , self .nnei , 1 ))
711+ nlist_tebd_j = xp .tile (atype_embd_nlist [:, None , :, :], (1 , self .nnei , 1 , 1 ))
698712 ng = self .neuron [- 1 ]
699713
700714 if self .tebd_input_mode in ["concat" ]:
701715 # nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
702- ss = np . concatenate ([ss , nlist_tebd_i , nlist_tebd_j ], axis = - 1 )
716+ ss = xp . concat ([ss , nlist_tebd_i , nlist_tebd_j ], axis = - 1 )
703717 # nfnl x nt_i x nt_j x ng
704718 gg = self .cal_g (ss , 0 )
705719 elif self .tebd_input_mode in ["strip" ]:
706720 # nfnl x nt_i x nt_j x ng
707721 gg_s = self .cal_g (ss , 0 )
708722 assert self .embeddings_strip is not None
709723 # nfnl x nt_i x nt_j x (tebd_dim * 2)
710- tt = np . concatenate ([nlist_tebd_i , nlist_tebd_j ], axis = - 1 )
724+ tt = xp . concat ([nlist_tebd_i , nlist_tebd_j ], axis = - 1 )
711725 # nfnl x nt_i x nt_j x ng
712726 gg_t = self .cal_g_strip (tt , 0 )
713727 if self .smooth :
714728 gg_t = (
715729 gg_t
716- * sw .reshape (nf * nloc , self .nnei , 1 , 1 )
717- * sw .reshape (nf * nloc , 1 , self .nnei , 1 )
730+ * xp .reshape (sw , ( nf * nloc , self .nnei , 1 , 1 ) )
731+ * xp .reshape (sw , ( nf * nloc , 1 , self .nnei , 1 ) )
718732 )
719733 # nfnl x nt_i x nt_j x ng
720734 gg = gg_s * gg_t + gg_s
721735 else :
722736 raise NotImplementedError
723737
724738 # nfnl x ng
725- res_ij = np .einsum ("ijk,ijkm->im" , env_ij , gg )
739+ # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
740+ res_ij = xp .sum (env_ij [:, :, :, None ] * gg [:, :, :, :], axis = (1 , 2 ))
726741 res_ij = res_ij * (1.0 / float (self .nnei ) / float (self .nnei ))
727742 # nf x nl x ng
728- result = res_ij .reshape (nf , nloc , self .filter_neuron [- 1 ]).astype (
729- GLOBAL_NP_FLOAT_PRECISION
730- )
743+ result = xp .reshape (res_ij , (nf , nloc , self .filter_neuron [- 1 ]))
744+ result = xp .astype (result , get_xp_precision (xp , "global" ))
731745 return (
732746 result ,
733747 None ,
@@ -743,3 +757,61 @@ def has_message_passing(self) -> bool:
743757 def need_sorted_nlist_for_lower (self ) -> bool :
744758 """Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
745759 return False
760+
761+ def serialize (self ) -> dict :
762+ """Serialize the descriptor to dict."""
763+ obj = self
764+ data = {
765+ "@class" : "Descriptor" ,
766+ "type" : "se_e3_tebd" ,
767+ "@version" : 1 ,
768+ "rcut" : obj .rcut ,
769+ "rcut_smth" : obj .rcut_smth ,
770+ "sel" : obj .sel ,
771+ "ntypes" : obj .ntypes ,
772+ "neuron" : obj .neuron ,
773+ "tebd_dim" : obj .tebd_dim ,
774+ "tebd_input_mode" : obj .tebd_input_mode ,
775+ "set_davg_zero" : obj .set_davg_zero ,
776+ "activation_function" : obj .activation_function ,
777+ "resnet_dt" : obj .resnet_dt ,
778+ # make deterministic
779+ "precision" : np .dtype (PRECISION_DICT [obj .precision ]).name ,
780+ "embeddings" : obj .embeddings .serialize (),
781+ "env_mat" : obj .env_mat .serialize (),
782+ "exclude_types" : obj .exclude_types ,
783+ "env_protection" : obj .env_protection ,
784+ "smooth" : obj .smooth ,
785+ "@variables" : {
786+ "davg" : to_numpy_array (obj ["davg" ]),
787+ "dstd" : to_numpy_array (obj ["dstd" ]),
788+ },
789+ }
790+ if obj .tebd_input_mode in ["strip" ]:
791+ data .update ({"embeddings_strip" : obj .embeddings_strip .serialize ()})
792+ return data
793+
794+ @classmethod
795+ def deserialize (cls , data : dict ) -> "DescrptSeTTebd" :
796+ """Deserialize from dict."""
797+ data = data .copy ()
798+ check_version_compatibility (data .pop ("@version" ), 1 , 1 )
799+ data .pop ("@class" )
800+ data .pop ("type" )
801+ variables = data .pop ("@variables" )
802+ embeddings = data .pop ("embeddings" )
803+ env_mat = data .pop ("env_mat" )
804+ tebd_input_mode = data ["tebd_input_mode" ]
805+ if tebd_input_mode in ["strip" ]:
806+ embeddings_strip = data .pop ("embeddings_strip" )
807+ else :
808+ embeddings_strip = None
809+ se_ttebd = cls (** data )
810+
811+ se_ttebd ["davg" ] = variables ["davg" ]
812+ se_ttebd ["dstd" ] = variables ["dstd" ]
813+ se_ttebd .embeddings = NetworkCollection .deserialize (embeddings )
814+ if tebd_input_mode in ["strip" ]:
815+ se_ttebd .embeddings_strip = NetworkCollection .deserialize (embeddings_strip )
816+
817+ return se_ttebd
0 commit comments