@@ -108,6 +108,13 @@ class DescrptSeAtten(DescrptSeA):
         Whether to mask the diagonal in the attention weights.
     multi_task
         If the model has multiple fitting nets to train.
+    stripped_type_embedding
+        Whether to strip the type embedding into a separate embedding network.
+        The default value will be True in the `se_atten_v2` descriptor.
+    smooth_type_embdding
+        When using stripped type embedding, whether to apply the smooth factor to the network
+        output of the type embedding to keep the network smooth, instead of setting `set_davg_zero` to True.
+        The default value will be True in the `se_atten_v2` descriptor.
     """

     def __init__(
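For orientation, here is a minimal usage sketch of the two new flags. It is not taken from this diff: the leading arguments (`rcut`, `rcut_smth`, `sel`) are assumed from the inherited `DescrptSeA` signature, and the values are illustrative only.

```python
from deepmd.descriptor.se_atten import DescrptSeAtten

# Sketch only; argument values are illustrative. Note that the parameter
# really is spelled `smooth_type_embdding` in the code above.
descrpt = DescrptSeAtten(
    rcut=6.0,                        # assumed inherited argument
    rcut_smth=0.5,                   # assumed inherited argument
    sel=120,                         # assumed inherited argument
    ntypes=2,
    stripped_type_embedding=True,    # strip type embedding into its own network
    smooth_type_embdding=True,       # multiply its output by the smooth switch
    set_davg_zero=False,             # allowed without a warning in this mode
)
```

With both flags set, the `set_davg_zero` warning below is suppressed, which is exactly what the changed condition in `__init__` implements.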
@@ -133,9 +140,10 @@ def __init__(
         attn_mask: bool = False,
         multi_task: bool = False,
         stripped_type_embedding: bool = False,
+        smooth_type_embdding: bool = False,
         **kwargs,
     ) -> None:
-        if not set_davg_zero:
+        if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding):
             warnings.warn(
                 "Set 'set_davg_zero' False in descriptor 'se_atten' "
                 "may cause unexpected discontinuity during model inference!"
@@ -166,6 +174,7 @@ def __init__(
             "2"
         ), "se_atten only supports TensorFlow version 2.0 or higher."
         self.stripped_type_embedding = stripped_type_embedding
+        self.smooth = smooth_type_embdding
         self.ntypes = ntypes
         self.att_n = attn
         self.attn_layer = attn_layer
@@ -607,6 +616,7 @@ def build(
             sel_a=self.sel_all_a,
             sel_r=self.sel_all_r,
         )
+
         self.nei_type_vec = tf.reshape(self.nei_type_vec, [-1])
         self.nmask = tf.cast(
             tf.reshape(self.nmask, [-1, 1, self.sel_all_a[0]]),
@@ -625,6 +635,41 @@ def build(
             tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1]
         )  ## lammps will have error without this
         self._identity_tensors(suffix=suffix)
+        if self.smooth:
+            self.sliced_avg = tf.reshape(
+                tf.slice(
+                    tf.reshape(self.t_avg, [self.ntypes, -1, 4]), [0, 0, 0], [-1, 1, 1]
+                ),
+                [self.ntypes, 1],
+            )
+            self.sliced_std = tf.reshape(
+                tf.slice(
+                    tf.reshape(self.t_std, [self.ntypes, -1, 4]), [0, 0, 0], [-1, 1, 1]
+                ),
+                [self.ntypes, 1],
+            )
+            self.avg_looked_up = tf.reshape(
+                tf.nn.embedding_lookup(self.sliced_avg, self.atype_nloc),
+                [-1, natoms[0], 1],
+            )
+            self.std_looked_up = tf.reshape(
+                tf.nn.embedding_lookup(self.sliced_std, self.atype_nloc),
+                [-1, natoms[0], 1],
+            )
+            self.recovered_r = (
+                tf.reshape(
+                    tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
+                    [-1, natoms[0], self.sel_all_a[0]],
+                )
+                * self.std_looked_up
+                + self.avg_looked_up
+            )
+            uu = 1 - self.rcut_r_smth * self.recovered_r
+            self.recovered_switch = -uu * uu * uu + 1
+            self.recovered_switch = tf.clip_by_value(self.recovered_switch, 0.0, 1.0)
+            self.recovered_switch = tf.cast(
+                self.recovered_switch, self.filter_precision
+            )

         self.dout, self.qmat = self._pass_filter(
             self.descrpt_reshape,
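As a reading aid: the block added above de-normalizes the first (1/r-like) column of the raw descriptor with the per-type `t_avg`/`t_std` statistics and maps it through a clipped cubic switch. A minimal NumPy sketch of the same arithmetic (names mirror the diff; shapes and broadcasting are simplified):

```python
import numpy as np

def smooth_switch(descrpt_col0, std_per_atom, avg_per_atom, rcut_r_smth):
    """NumPy stand-in for the `recovered_switch` computation above."""
    # De-normalize the first descriptor column: descrpt * std + avg.
    recovered_r = descrpt_col0 * std_per_atom + avg_per_atom
    # Cubic switch 1 - (1 - rcut_r_smth * s)**3, clipped to [0, 1].
    uu = 1.0 - rcut_r_smth * recovered_r
    return np.clip(1.0 - uu**3, 0.0, 1.0)
```

The clip pins the factor at 0 for distant neighbors (where the 1/r-like term vanishes) and at 1 well inside the smooth region, so anything multiplied by it turns off continuously at the cutoff.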
@@ -1146,9 +1191,10 @@ def _filter_lower(
             two_embd = tf.nn.embedding_lookup(
                 embedding_of_two_side_type_embedding, index_of_two_side
             )
-
+            if self.smooth:
+                two_embd = two_embd * tf.reshape(self.recovered_switch, [-1, 1])
             if not self.compress:
-                xyz_scatter = xyz_scatter * two_embd + two_embd
+                xyz_scatter = xyz_scatter * two_embd + xyz_scatter
             else:
                 return op_module.tabulate_fusion_se_atten(
                     tf.cast(self.table.data[net], self.filter_precision),
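Two things happen in this last hunk: the two-side type embedding is scaled by the smooth factor, and the residual term is corrected from `+ two_embd` to `+ xyz_scatter`, so the type embedding acts as a gate on the radial embedding. A toy NumPy sketch of the combined effect (shapes simplified; names mirror the diff):

```python
import numpy as np

nrows, channels = 6, 8                          # toy sizes only
xyz_scatter = np.random.rand(nrows, channels)   # radial embedding output
two_embd = np.random.rand(nrows, channels)      # two-side type embedding
recovered_switch = np.random.rand(nrows)        # smooth factor in [0, 1]

# As in the diff: smooth the type embedding, then gate the radial embedding.
two_embd = two_embd * recovered_switch.reshape(-1, 1)
out = xyz_scatter * two_embd + xyz_scatter      # == xyz_scatter * (1 + two_embd)
```

Because `out == xyz_scatter * (1 + two_embd)`, the output stays proportional to `xyz_scatter`: as the switch drives `two_embd` to zero at the cutoff, the result falls back to the plain radial embedding instead of leaving behind the constant `two_embd` term of the old expression.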