Skip to content

Commit 835d6e5

Browse files
iProzd, nahso, pre-commit-ci[bot]
authored
Change the smooth method in se_atten descriptor (#2755)
This PR changes the smoothing method of the se_atten descriptor when stripped type embedding is used: 1. Change the structure of the stripped type embedding and the compression ops. 2. Apply a smooth factor to the network output of the type embedding. 3. Add a new descriptor, se_atten_v2, that includes these changes while se_atten keeps its old behavior. --------- Co-authored-by: nahso <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a397f0c commit 835d6e5

File tree

27 files changed

+596
-396
lines changed

27 files changed

+596
-396
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
102102
- [Descriptor `"se_e2_r"`](doc/model/train-se-e2-r.md)
103103
- [Descriptor `"se_e3"`](doc/model/train-se-e3.md)
104104
- [Descriptor `"se_atten"`](doc/model/train-se-atten.md)
105+
- [Descriptor `"se_atten_v2"`](doc/model/train-se-atten.md#descriptor-se_atten_v2)
105106
- [Descriptor `"hybrid"`](doc/model/train-hybrid.md)
106107
- [Descriptor `sel`](doc/model/sel.md)
107108
- [Fit energy](doc/model/train-energy.md)

deepmd/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from .se_atten import (
2525
DescrptSeAtten,
2626
)
27+
from .se_atten_v2 import (
28+
DescrptSeAttenV2,
29+
)
2730
from .se_r import (
2831
DescrptSeR,
2932
)
@@ -41,6 +44,7 @@
4144
"DescrptSeAEfLower",
4245
"DescrptSeAMask",
4346
"DescrptSeAtten",
47+
"DescrptSeAttenV2",
4448
"DescrptSeR",
4549
"DescrptSeT",
4650
]

deepmd/descriptor/se_atten.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ class DescrptSeAtten(DescrptSeA):
108108
Whether to mask the diagonal in the attention weights.
109109
multi_task
110110
If the model has multi fitting nets to train.
111+
stripped_type_embedding
112+
Whether to strip the type embedding into a separated embedding network.
113+
Default value will be True in `se_atten_v2` descriptor.
114+
smooth_type_embdding
115+
When using stripped type embedding, whether to dot smooth factor on the network output of type embedding
116+
to keep the network smooth, instead of setting `set_davg_zero` to be True.
117+
Default value will be True in `se_atten_v2` descriptor.
111118
"""
112119

113120
def __init__(
@@ -133,9 +140,10 @@ def __init__(
133140
attn_mask: bool = False,
134141
multi_task: bool = False,
135142
stripped_type_embedding: bool = False,
143+
smooth_type_embdding: bool = False,
136144
**kwargs,
137145
) -> None:
138-
if not set_davg_zero:
146+
if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding):
139147
warnings.warn(
140148
"Set 'set_davg_zero' False in descriptor 'se_atten' "
141149
"may cause unexpected incontinuity during model inference!"
@@ -166,6 +174,7 @@ def __init__(
166174
"2"
167175
), "se_atten only support tensorflow version 2.0 or higher."
168176
self.stripped_type_embedding = stripped_type_embedding
177+
self.smooth = smooth_type_embdding
169178
self.ntypes = ntypes
170179
self.att_n = attn
171180
self.attn_layer = attn_layer
@@ -607,6 +616,7 @@ def build(
607616
sel_a=self.sel_all_a,
608617
sel_r=self.sel_all_r,
609618
)
619+
610620
self.nei_type_vec = tf.reshape(self.nei_type_vec, [-1])
611621
self.nmask = tf.cast(
612622
tf.reshape(self.nmask, [-1, 1, self.sel_all_a[0]]),
@@ -625,6 +635,41 @@ def build(
625635
tf.slice(atype, [0, 0], [-1, natoms[0]]), [-1]
626636
) ## lammps will have error without this
627637
self._identity_tensors(suffix=suffix)
638+
if self.smooth:
639+
self.sliced_avg = tf.reshape(
640+
tf.slice(
641+
tf.reshape(self.t_avg, [self.ntypes, -1, 4]), [0, 0, 0], [-1, 1, 1]
642+
),
643+
[self.ntypes, 1],
644+
)
645+
self.sliced_std = tf.reshape(
646+
tf.slice(
647+
tf.reshape(self.t_std, [self.ntypes, -1, 4]), [0, 0, 0], [-1, 1, 1]
648+
),
649+
[self.ntypes, 1],
650+
)
651+
self.avg_looked_up = tf.reshape(
652+
tf.nn.embedding_lookup(self.sliced_avg, self.atype_nloc),
653+
[-1, natoms[0], 1],
654+
)
655+
self.std_looked_up = tf.reshape(
656+
tf.nn.embedding_lookup(self.sliced_std, self.atype_nloc),
657+
[-1, natoms[0], 1],
658+
)
659+
self.recovered_r = (
660+
tf.reshape(
661+
tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
662+
[-1, natoms[0], self.sel_all_a[0]],
663+
)
664+
* self.std_looked_up
665+
+ self.avg_looked_up
666+
)
667+
uu = 1 - self.rcut_r_smth * self.recovered_r
668+
self.recovered_switch = -uu * uu * uu + 1
669+
self.recovered_switch = tf.clip_by_value(self.recovered_switch, 0.0, 1.0)
670+
self.recovered_switch = tf.cast(
671+
self.recovered_switch, self.filter_precision
672+
)
628673

629674
self.dout, self.qmat = self._pass_filter(
630675
self.descrpt_reshape,
@@ -1146,9 +1191,10 @@ def _filter_lower(
11461191
two_embd = tf.nn.embedding_lookup(
11471192
embedding_of_two_side_type_embedding, index_of_two_side
11481193
)
1149-
1194+
if self.smooth:
1195+
two_embd = two_embd * tf.reshape(self.recovered_switch, [-1, 1])
11501196
if not self.compress:
1151-
xyz_scatter = xyz_scatter * two_embd + two_embd
1197+
xyz_scatter = xyz_scatter * two_embd + xyz_scatter
11521198
else:
11531199
return op_module.tabulate_fusion_se_atten(
11541200
tf.cast(self.table.data[net], self.filter_precision),

deepmd/descriptor/se_atten_v2.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import logging
3+
from typing import (
4+
List,
5+
Optional,
6+
)
7+
8+
from .descriptor import (
9+
Descriptor,
10+
)
11+
from .se_atten import (
12+
DescrptSeAtten,
13+
)
14+
15+
log = logging.getLogger(__name__)
16+
17+
18+
@Descriptor.register("se_atten_v2")
19+
class DescrptSeAttenV2(DescrptSeAtten):
20+
r"""Smooth version 2.0 descriptor with attention.
21+
22+
Parameters
23+
----------
24+
rcut
25+
The cut-off radius :math:`r_c`
26+
rcut_smth
27+
From where the environment matrix should be smoothed :math:`r_s`
28+
sel : list[str]
29+
sel[i] specifies the maxmum number of type i atoms in the cut-off radius
30+
neuron : list[int]
31+
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
32+
axis_neuron
33+
Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
34+
resnet_dt
35+
Time-step `dt` in the resnet construction:
36+
y = x + dt * \phi (Wx + b)
37+
trainable
38+
If the weights of embedding net are trainable.
39+
seed
40+
Random seed for initializing the network parameters.
41+
type_one_side
42+
Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets
43+
exclude_types : List[List[int]]
44+
The excluded pairs of types which have no interaction with each other.
45+
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
46+
set_davg_zero
47+
Set the shift of embedding net input to zero.
48+
activation_function
49+
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
50+
precision
51+
The precision of the embedding net parameters. Supported options are |PRECISION|
52+
uniform_seed
53+
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
54+
attn
55+
The length of hidden vector during scale-dot attention computation.
56+
attn_layer
57+
The number of layers in attention mechanism.
58+
attn_dotr
59+
Whether to dot the relative coordinates on the attention weights as a gated scheme.
60+
attn_mask
61+
Whether to mask the diagonal in the attention weights.
62+
multi_task
63+
If the model has multi fitting nets to train.
64+
"""
65+
66+
def __init__(
67+
self,
68+
rcut: float,
69+
rcut_smth: float,
70+
sel: int,
71+
ntypes: int,
72+
neuron: List[int] = [24, 48, 96],
73+
axis_neuron: int = 8,
74+
resnet_dt: bool = False,
75+
trainable: bool = True,
76+
seed: Optional[int] = None,
77+
type_one_side: bool = True,
78+
set_davg_zero: bool = False,
79+
exclude_types: List[List[int]] = [],
80+
activation_function: str = "tanh",
81+
precision: str = "default",
82+
uniform_seed: bool = False,
83+
attn: int = 128,
84+
attn_layer: int = 2,
85+
attn_dotr: bool = True,
86+
attn_mask: bool = False,
87+
multi_task: bool = False,
88+
**kwargs,
89+
) -> None:
90+
DescrptSeAtten.__init__(
91+
self,
92+
rcut,
93+
rcut_smth,
94+
sel,
95+
ntypes,
96+
neuron=neuron,
97+
axis_neuron=axis_neuron,
98+
resnet_dt=resnet_dt,
99+
trainable=trainable,
100+
seed=seed,
101+
type_one_side=type_one_side,
102+
set_davg_zero=set_davg_zero,
103+
exclude_types=exclude_types,
104+
activation_function=activation_function,
105+
precision=precision,
106+
uniform_seed=uniform_seed,
107+
attn=attn,
108+
attn_layer=attn_layer,
109+
attn_dotr=attn_dotr,
110+
attn_mask=attn_mask,
111+
multi_task=multi_task,
112+
stripped_type_embedding=True,
113+
smooth_type_embdding=True,
114+
**kwargs,
115+
)

deepmd/entrypoints/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,15 @@ def update_one_sel(jdata, descriptor):
476476
if descriptor["type"] == "loc_frame":
477477
return descriptor
478478
rcut = descriptor["rcut"]
479-
tmp_sel = get_sel(jdata, rcut, one_type=descriptor["type"] in ("se_atten",))
479+
tmp_sel = get_sel(
480+
jdata,
481+
rcut,
482+
one_type=descriptor["type"]
483+
in (
484+
"se_atten",
485+
"se_atten_v2",
486+
),
487+
)
480488
sel = descriptor["sel"]
481489
if isinstance(sel, int):
482490
# convert to list and finnally convert back to int
@@ -495,7 +503,10 @@ def update_one_sel(jdata, descriptor):
495503
"not less than %d, but you set it to %d. The accuracy"
496504
" of your model may get worse." % (ii, tt, dd)
497505
)
498-
if descriptor["type"] in ("se_atten",):
506+
if descriptor["type"] in (
507+
"se_atten",
508+
"se_atten_v2",
509+
):
499510
descriptor["sel"] = sel = sum(sel)
500511
return descriptor
501512

deepmd/utils/argcheck.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,7 @@ def descrpt_hybrid_args():
333333
]
334334

335335

336-
@descrpt_args_plugin.register("se_atten")
337-
def descrpt_se_atten_args():
338-
doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible."
336+
def descrpt_se_atten_common_args():
339337
doc_sel = 'This parameter set the number of selected neighbors. Note that this parameter is a little different from that in other descriptors. Instead of separating each type of atoms, only the summation matters. And this number is highly related with the efficiency, thus one should not make it too large. Usually 200 or less is enough, far away from the GPU limitation 4096. It can be:\n\n\
340338
- `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\
341339
- `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. Only the summation of `sel[i]` matters, and it is recommended to be less than 200.\
@@ -350,21 +348,13 @@ def descrpt_se_atten_args():
350348
doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
351349
doc_trainable = "If the parameters in the embedding net is trainable"
352350
doc_seed = "Random seed for parameter initialization"
353-
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"
354351
doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1."
355352
doc_attn = "The length of hidden vectors in attention layers"
356353
doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and stripped_type_embedding is True"
357354
doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates"
358355
doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"
359356

360357
return [
361-
Argument(
362-
"stripped_type_embedding",
363-
bool,
364-
optional=True,
365-
default=False,
366-
doc=doc_stripped_type_embedding,
367-
),
368358
Argument("sel", [int, list, str], optional=True, default="auto", doc=doc_sel),
369359
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
370360
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
@@ -394,16 +384,51 @@ def descrpt_se_atten_args():
394384
Argument(
395385
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
396386
),
397-
Argument(
398-
"set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero
399-
),
400387
Argument("attn", int, optional=True, default=128, doc=doc_attn),
401388
Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
402389
Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr),
403390
Argument("attn_mask", bool, optional=True, default=False, doc=doc_attn_mask),
404391
]
405392

406393

394+
@descrpt_args_plugin.register("se_atten")
def descrpt_se_atten_args():
    """Return the argument definitions of the `se_atten` descriptor.

    The result is the list from `descrpt_se_atten_common_args()` followed by
    three optional boolean arguments specific to `se_atten`:
    `stripped_type_embedding` (default False), `smooth_type_embdding`
    (default False) and `set_davg_zero` (default True).
    """

    def _optional_bool(name, default, doc):
        # Every se_atten-specific option is an optional boolean switch.
        return Argument(name, bool, optional=True, default=default, doc=doc)

    shared_args = descrpt_se_atten_common_args()

    specific_args = [
        _optional_bool(
            "stripped_type_embedding",
            False,
            "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible.",
        ),
        _optional_bool(
            "smooth_type_embdding",
            False,
            "When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True.",
        ),
        _optional_bool(
            "set_davg_zero",
            True,
            "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used",
        ),
    ]
    return shared_args + specific_args
419+
420+
421+
@descrpt_args_plugin.register("se_atten_v2")
def descrpt_se_atten_v2_args():
    """Return the argument definitions of the `se_atten_v2` descriptor.

    The result is the list from `descrpt_se_atten_common_args()` followed by
    a single optional boolean `set_davg_zero` argument whose default is False.
    """
    shared_args = descrpt_se_atten_common_args()
    set_davg_zero_arg = Argument(
        "set_davg_zero",
        bool,
        optional=True,
        default=False,
        doc="Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used",
    )
    return shared_args + [set_davg_zero_arg]
430+
431+
407432
@descrpt_args_plugin.register("se_a_mask")
408433
def descrpt_se_a_mask_args():
409434
doc_sel = 'This parameter sets the number of selected neighbors for each type of atom. It can be:\n\n\
@@ -459,13 +484,15 @@ def descrpt_variant_type_args(exclude_hybrid: bool = False) -> Variant:
459484
link_se_a_tpe = make_link("se_a_tpe", "model/descriptor[se_a_tpe]")
460485
link_hybrid = make_link("hybrid", "model/descriptor[hybrid]")
461486
link_se_atten = make_link("se_atten", "model/descriptor[se_atten]")
487+
link_se_atten_v2 = make_link("se_atten_v2", "model/descriptor[se_atten_v2]")
462488
doc_descrpt_type = "The type of the descritpor. See explanation below. \n\n\
463489
- `loc_frame`: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\
464490
- `se_e2_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\
465491
- `se_e2_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\
466492
- `se_e3`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n\
467493
- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\
468494
- `se_atten`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism will be used by this descriptor.\n\n\
495+
- `se_atten_v2`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism with new modifications will be used by this descriptor.\n\n\
469496
- `se_a_mask`: Used by the smooth edition of Deep Potential. It can accept a variable number of atoms in a frame (Non-PBC system). *aparam* are required as an indicator matrix for the real/virtual sign of input atoms. \n\n\
470497
- `hybrid`: Concatenate of a list of descriptors as a new descriptor."
471498

deepmd/utils/finetune.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ def replace_model_params_with_pretrained_model(
4242

4343
# Check the model type
4444
assert pretrained_jdata["model"]["descriptor"]["type"] in [
45-
"se_atten"
45+
"se_atten",
46+
"se_atten_v2",
4647
] and pretrained_jdata["model"]["fitting_net"]["type"] in [
4748
"ener"
48-
], "The finetune process only supports models pretrained with 'se_atten' descriptor and 'ener' fitting_net!"
49+
], "The finetune process only supports models pretrained with 'se_atten' or 'se_atten_v2' descriptor and 'ener' fitting_net!"
4950

5051
# Check the type map
5152
pretrained_type_map = pretrained_jdata["model"]["type_map"]

doc/credits.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Cite DeePMD-kit and methods
4242

4343
Wang_NuclFusion_2022_v62_p126013
4444

45-
- If attention-based descriptor (`se_atten`) is used,
45+
- If attention-based descriptor (`se_atten`, `se_atten_v2`) is used,
4646

4747
.. bibliography::
4848
:filter: False

0 commit comments

Comments
 (0)