Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 90 additions & 12 deletions deepmd_pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from deepmd_pt.utils.nlist import build_multiple_neighbor_list

from .se_atten import analyze_descrpt
from .se_atten import DescrptBlockSeAtten
from .se_atten import DescrptBlockSeAtten, NeighborGatedAttention
from deepmd_pt.model.network.mlp import EmbdLayer, NetworkCollection
from deepmd_utils.model_format import (
EnvMat as DPEnvMat,
)

@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
Expand All @@ -41,14 +45,16 @@ def __init__(
post_ln=True,
ffn=False,
ffn_embed_dim=1024,
activation="tanh",
activation_function="tanh",
precision: str = "float64",
resnet_dt: bool = False,
scaling_factor=1.0,
head_num=1,
normalize=True,
temperature=None,
return_rot=False,
concat_output_tebd: bool = True,
type: Optional[str] = None,
old_impl: bool = False,
**kwargs,
):
super(DescrptDPA1, self).__init__()
del type
Expand All @@ -63,17 +69,22 @@ def __init__(
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
post_ln=post_ln,
ffn=ffn,
ffn_embed_dim=ffn_embed_dim,
activation=activation,
activation_function=activation_function,
precision=precision,
resnet_dt=resnet_dt,
scaling_factor=scaling_factor,
head_num=head_num,
normalize=normalize,
temperature=temperature,
return_rot=return_rot,
old_impl=old_impl,
**kwargs,
)
self.type_embedding = TypeEmbedNet(ntypes, tebd_dim)
self.type_embedding_old = None
self.type_embedding = None
self.old_impl = old_impl
if self.old_impl:
self.type_embedding_old = TypeEmbedNet(ntypes, tebd_dim)
else:
self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision)
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd

Expand Down Expand Up @@ -147,7 +158,12 @@ def forward(
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
if self.old_impl:
assert self.type_embedding_old is not None
g1_ext = self.type_embedding_old(extended_atype)
else:
assert self.type_embedding is not None
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:,:nloc,:]
g1, env_mat, diff, rot_mat, sw = self.se_atten(
nlist,
Expand All @@ -158,5 +174,67 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, env_mat, diff, rot_mat, sw

def set_stat_mean_and_stddev(
    self,
    mean: torch.Tensor,
    stddev: torch.Tensor,
) -> None:
    """Store the given statistics on the wrapped ``se_atten`` block.

    Parameters
    ----------
    mean
        Value assigned to ``self.se_atten.mean``.
    stddev
        Value assigned to ``self.se_atten.stddev``.
    """
    block = self.se_atten
    # Assigned left to right: ``mean`` first, then ``stddev``.
    block.mean, block.stddev = mean, stddev

def serialize(self) -> dict:
    """Serialize the descriptor into a plain dictionary.

    The hyper-parameters are read from the wrapped ``se_atten`` block,
    ``concat_output_tebd`` and the type embedding are read from ``self``,
    and the ``davg``/``dstd`` statistics are exported as NumPy arrays
    under the ``"@variables"`` key.

    Returns
    -------
    dict
        The serialized descriptor.

    Raises
    ------
    NotImplementedError
        If ``self.type_embedding`` is ``None``. The constructor leaves it
        unset when ``old_impl`` is true, and in that state there is
        nothing to serialize for the ``"type_embedding"`` entry.
    """
    # Fail early with a clear message instead of an AttributeError on None
    # in the middle of building the dictionary.
    if self.type_embedding is None:
        raise NotImplementedError(
            "serialize requires self.type_embedding to be set; "
            "it is None (was the descriptor built with old_impl=True?)"
        )
    obj = self.se_atten
    return {
        "rcut": obj.rcut,
        "rcut_smth": obj.rcut_smth,
        "sel": obj.sel,
        "ntypes": obj.ntypes,
        "neuron": obj.neuron,
        "axis_neuron": obj.axis_neuron,
        "tebd_dim": obj.tebd_dim,
        "tebd_input_mode": obj.tebd_input_mode,
        "set_davg_zero": obj.set_davg_zero,
        "attn": obj.attn_dim,
        "attn_layer": obj.attn_layer,
        "attn_dotr": obj.attn_dotr,
        "attn_mask": obj.attn_mask,
        "activation_function": obj.activation_function,
        "precision": obj.precision,
        "resnet_dt": obj.resnet_dt,
        "scaling_factor": obj.scaling_factor,
        "normalize": obj.normalize,
        "temperature": obj.temperature,
        "concat_output_tebd": self.concat_output_tebd,
        "embeddings": obj.filter_layers.serialize(),
        "attention_layers": obj.dpa1_attention.serialize(),
        "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
        "type_embedding": self.type_embedding.serialize(),
        "@variables": {
            "davg": obj["davg"].detach().cpu().numpy(),
            "dstd": obj["dstd"].detach().cpu().numpy(),
        },
        # Hard-coded placeholders; to be updated when these options
        # are supported.
        "trainable": True,
        "type_one_side": True,
        "exclude_types": [],
        "spin": None,
    }

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
    """Rebuild a descriptor from a dictionary produced by ``serialize``.

    Parameters
    ----------
    data
        The serialized descriptor. It is not modified; a shallow copy is
        taken before any key is removed.

    Returns
    -------
    DescrptDPA1
        A new instance constructed from the remaining entries of ``data``
        as keyword arguments, with the type embedding, the ``davg``/``dstd``
        statistics, the embedding networks and the attention layers
        restored from their serialized form.

    Raises
    ------
    KeyError
        If any of ``"@variables"``, ``"embeddings"``, ``"type_embedding"``,
        ``"attention_layers"`` or ``"env_mat"`` is missing from ``data``.
    """
    # Work on a shallow copy so the pops below do not strip keys from the
    # caller's dictionary (which would make it unusable for a second call).
    data = data.copy()
    variables = data.pop("@variables")
    embeddings = data.pop("embeddings")
    type_embedding = data.pop("type_embedding")
    attention_layers = data.pop("attention_layers")
    # The serialized environment matrix is not forwarded to the
    # constructor; remove it and discard the value.
    data.pop("env_mat")
    obj = cls(**data)

    def t_cvt(xx):
        # Convert stored statistics to tensors with the block's precision
        # on the configured device.
        return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE)

    obj.type_embedding = EmbdLayer.deserialize(type_embedding)
    obj.se_atten["davg"] = t_cvt(variables["davg"])
    obj.se_atten["dstd"] = t_cvt(variables["dstd"])
    obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
    obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
    return obj


2 changes: 1 addition & 1 deletion deepmd_pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def forward(
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros([nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE)
for ii,ll in enumerate(self.filter_layers.networks):
for ii,ll in enumerate(self.filter_layers._networks):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to provide a method for accessing the data of a class object.

# nfnl x nt x 4
rr = dmatrix[:, self.sec[ii]:self.sec[ii+1], :]
ss = rr[:,:,:1]
Expand Down
Loading