
Commit 982dc55

iProzd authored and OutisLi committed
update dynamic sel
1 parent 667af15 commit 982dc55

File tree

4 files changed: +237 -108 lines changed


deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 29 additions & 97 deletions
@@ -210,12 +210,8 @@ def __init__(
             self.n_a_compress_dim = n_dim
         else:
             # angle + a_dim/c + a_dim/2c * 2 * e_rate
-            self.angle_dim += (1 + self.a_compress_e_rate) * (
-                self.a_dim // self.a_compress_rate
-            )
-            self.e_a_compress_dim = (
-                self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
-            )
+            self.angle_dim += (1 + self.a_compress_e_rate) * (self.a_dim // self.a_compress_rate)
+            self.e_a_compress_dim = self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
             self.n_a_compress_dim = self.a_dim // self.a_compress_rate
             if not self.a_compress_use_split:
                 self.a_compress_n_linear = MLPLayer(
@@ -331,10 +327,7 @@ def _cal_hg(
         # nb x nloc x nnei x e_dim
         edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask)
         edge_ebd = _apply_switch(edge_ebd, sw)
-        invnnei = torch.rsqrt(
-            float(nnei)
-            * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device)
-        )
+        invnnei = torch.rsqrt(float(nnei) * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device))
         # nb x nloc x 3 x e_dim
         h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei
         return h2g2
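For context, the collapsed invnnei expression is just a broadcast 1/sqrt(nnei) normalization of the neighbor-contracted matmul. A minimal standalone sketch (shapes and values hypothetical, not from this diff):

import torch

nb, nloc, nnei, e_dim = 2, 4, 8, 16
h2 = torch.randn(nb, nloc, nnei, 3)
edge_ebd = torch.randn(nb, nloc, nnei, e_dim)
# 1/sqrt(nnei), broadcast as nb x nloc x 1 x 1
invnnei = torch.rsqrt(float(nnei) * torch.ones((nb, nloc, 1, 1)))
# nb x nloc x 3 x e_dim, normalized by sqrt(nnei)
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei
assert h2g2.shape == (nb, nloc, 3, e_dim)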
@@ -382,16 +375,9 @@ def _cal_hg_dynamic(
         # n_edge x e_dim
         flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
         # n_edge x 3 x e_dim
-        flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
-            -1, 3 * e_dim
-        )
+        flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(-1, 3 * e_dim)
         # nf x nloc x 3 x e_dim
-        h2g2 = (
-            aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(
-                nb, nloc, 3, e_dim
-            )
-            * scale_factor
-        )
+        h2g2 = aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(nb, nloc, 3, e_dim) * scale_factor
         return h2g2

     @staticmethod
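Assuming aggregate scatter-sums rows by their owner index (it lives in deepmd/pt/model/network/utils.py, also touched below), the dynamic path builds the same nb x nloc x 3 x e_dim tensor as _cal_hg, just from a flat edge list. A toy sketch of the flat outer product and reshape (sizes hypothetical):

import torch

n_edge, e_dim = 10, 16
flat_h2 = torch.randn(n_edge, 3)
flat_edge_ebd = torch.randn(n_edge, e_dim)
# per-edge outer product, flattened to n_edge x (3 * e_dim) for aggregation
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(-1, 3 * e_dim)
assert flat_h2g2.shape == (n_edge, 3 * e_dim)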
@@ -544,9 +530,7 @@ def optim_angle_update(
         node_dim = node_ebd.shape[-1]
         edge_dim = edge_ebd.shape[-1]
         # angle_dim, node_dim, edge_dim, edge_dim
-        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(
-            matrix, [angle_dim, node_dim, edge_dim, edge_dim]
-        )
+        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(matrix, [angle_dim, node_dim, edge_dim, edge_dim])

         # nf * nloc * a_sel * a_sel * angle_dim
         sub_angle_update = torch.matmul(angle_ebd, sub_angle)
@@ -557,11 +541,7 @@ def optim_angle_update(
         sub_edge_update_ij = torch.matmul(edge_ebd, sub_edge_ij)

         result_update = (
-            bias
-            + sub_node_update.unsqueeze(2).unsqueeze(3)
-            + sub_edge_update_ik.unsqueeze(2)
-            + sub_edge_update_ij.unsqueeze(3)
-            + sub_angle_update
+            bias + sub_node_update.unsqueeze(2).unsqueeze(3) + sub_edge_update_ik.unsqueeze(2) + sub_edge_update_ij.unsqueeze(3) + sub_angle_update
         )
         return result_update

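The optim_* path relies on a standard identity: multiplying a concatenated input by one fused weight matrix equals summing per-block matmuls after torch.split. A small self-check under hypothetical dimensions:

import torch

a_dim, n_dim, e_dim, out_dim = 4, 6, 8, 5
matrix = torch.randn(a_dim + n_dim + 2 * e_dim, out_dim)
angle = torch.randn(a_dim)
node = torch.randn(n_dim)
edge_ik = torch.randn(e_dim)
edge_ij = torch.randn(e_dim)
# split the fused matrix along its input dimension, as in optim_angle_update
sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(matrix, [a_dim, n_dim, e_dim, e_dim])
fused = torch.cat([angle, node, edge_ik, edge_ij]) @ matrix
split = angle @ sub_angle + node @ sub_node + edge_ik @ sub_edge_ik + edge_ij @ sub_edge_ij
assert torch.allclose(fused, split, atol=1e-5)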
@@ -585,19 +565,15 @@ def optim_angle_update_dynamic(
         edge_dim = flat_edge_ebd.shape[-1]
         angle_dim = flat_angle_ebd.shape[-1]
         # angle_dim, node_dim, edge_dim, edge_dim
-        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(
-            matrix, [angle_dim, node_dim, edge_dim, edge_dim]
-        )
+        sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(matrix, [angle_dim, node_dim, edge_dim, edge_dim])

         # n_angle * angle_dim
         sub_angle_update = torch.matmul(flat_angle_ebd, sub_angle)

         # nf * nloc * angle_dim
         sub_node_update = torch.matmul(node_ebd, sub_node)
         # n_angle * angle_dim
-        sub_node_update = torch.index_select(
-            sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2a_index
-        )
+        sub_node_update = torch.index_select(sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2a_index)

         # n_edge * angle_dim
         sub_edge_update_ik = torch.matmul(flat_edge_ebd, sub_edge_ik)
@@ -606,13 +582,7 @@ def optim_angle_update_dynamic(
         sub_edge_update_ik = torch.index_select(sub_edge_update_ik, 0, eik2a_index)
         sub_edge_update_ij = torch.index_select(sub_edge_update_ij, 0, eij2a_index)

-        result_update = (
-            bias
-            + sub_node_update
-            + sub_edge_update_ik
-            + sub_edge_update_ij
-            + sub_angle_update
-        )
+        result_update = bias + sub_node_update + sub_edge_update_ik + sub_edge_update_ij + sub_angle_update
         return result_update

     def optim_edge_update(
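The *_dynamic variants compute each matmul once per node and then gather rows into per-angle (or per-edge) order with torch.index_select, instead of broadcasting over padded neighbor axes. A minimal sketch of that gather (indices hypothetical):

import torch

nf, nloc, dim = 1, 3, 4
sub_node_update = torch.randn(nf, nloc, dim)
# hypothetical map from each angle to its center node
n2a_index = torch.tensor([0, 0, 1, 2, 2])
# n_angle x dim: one row per angle, gathered from the per-node result
per_angle = torch.index_select(sub_node_update.reshape(nf * nloc, dim), 0, n2a_index)
assert per_angle.shape == (5, dim)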
@@ -645,9 +615,7 @@ def optim_edge_update(
         # nf * nloc * nnei * node/edge_dim
         sub_edge_update = torch.matmul(edge_ebd, edge)

-        result_update = (
-            bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
-        )
+        result_update = bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
         return result_update

     def optim_edge_update_dynamic(
@@ -675,9 +643,7 @@ def optim_edge_update_dynamic(
         # nf * nloc * node/edge_dim
         sub_node_update = torch.matmul(node_ebd, node)
         # n_edge * node/edge_dim
-        sub_node_update = torch.index_select(
-            sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2e_index
-        )
+        sub_node_update = torch.index_select(sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2e_index)

         # nf * nall * node/edge_dim
         sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
@@ -757,14 +723,12 @@ def forward(
         nb, nloc, nnei = nlist.shape
         nall = node_ebd_ext.shape[1]
         node_ebd = node_ebd_ext[:, :nloc, :]
+        n_edge = int(nlist_mask.sum().item()) if self.use_dynamic_sel else 0
         assert (nb, nloc) == node_ebd.shape[:2]
         if not self.use_dynamic_sel:
             assert (nb, nloc, nnei, 3) == h2.shape
-            n_edge = None
         else:
-            # n_edge = int(nlist_mask.sum().item())
-            # assert (n_edge, 3) == h2.shape
-            n_edge = h2.shape[0]
+            assert (n_edge, 3) == h2.shape
         del a_nlist  # may be used in the future

         n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
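n_edge is now derived up front from the neighbor-list mask rather than read back from h2.shape, which lets the previously commented-out (n_edge, 3) assertion be enabled. A toy illustration of the count (mask values hypothetical):

import torch

# nb x nloc x nnei mask of real (non-padded) neighbors
nlist_mask = torch.tensor([[[True, True, False], [True, False, False]]])
n_edge = int(nlist_mask.sum().item())
assert n_edge == 3  # three real edges in this toy mask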
@@ -778,9 +742,7 @@ def forward(
         nei_node_ebd = (
             _make_nei_g1(node_ebd_ext, nlist)
             if not self.use_dynamic_sel
-            else torch.index_select(
-                node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index
-            )
+            else torch.index_select(node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index)
         )

         n_update_list: list[torch.Tensor] = [node_ebd]
@@ -853,9 +815,7 @@ def forward(
             # n_edge x (n_dim * 2 + e_dim)
             edge_info = torch.cat(
                 [
-                    torch.index_select(
-                        node_ebd.reshape(-1, self.n_dim), 0, n2e_index
-                    ),
+                    torch.index_select(node_ebd.reshape(-1, self.n_dim), 0, n2e_index),
                     nei_node_ebd,
                     edge_ebd,
                 ],
@@ -868,9 +828,7 @@ def forward(
         # nb x nloc x nnei x (h * n_dim)
         if not self.optim_update:
             assert edge_info is not None
-            node_edge_update = self.act(
-                self.node_edge_linear(edge_info)
-            ) * sw.unsqueeze(-1)
+            node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1)
         else:
             node_edge_update = self.act(
                 self.optim_edge_update(
@@ -906,9 +864,7 @@ def forward(

         if self.n_multi_edge_message > 1:
             # nb x nloc x h x n_dim
-            node_edge_update_mul_head = node_edge_update.view(
-                nb, nloc, self.n_multi_edge_message, self.n_dim
-            )
+            node_edge_update_mul_head = node_edge_update.view(nb, nloc, self.n_multi_edge_message, self.n_dim)
             for head_index in range(self.n_multi_edge_message):
                 n_update_list.append(node_edge_update_mul_head[..., head_index, :])
         else:
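When n_multi_edge_message > 1, the linear layer emits all heads in one tensor and the view splits them so each head is appended as its own residual update. A toy version of the reshape (sizes hypothetical):

import torch

nb, nloc, heads, n_dim = 2, 4, 3, 8
node_edge_update = torch.randn(nb, nloc, heads * n_dim)
# nb x nloc x heads x n_dim, then one nb x nloc x n_dim update per head
mul_head = node_edge_update.view(nb, nloc, heads, n_dim)
updates = [mul_head[..., h, :] for h in range(heads)]
assert all(u.shape == (nb, nloc, n_dim) for u in updates)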
@@ -964,9 +920,7 @@ def forward(
             # nb x nloc x a_nnei x e_dim
             edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :]
             # nb x nloc x a_nnei x e_dim
-            edge_ebd_for_angle = torch.where(
-                a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0
-            )
+            edge_ebd_for_angle = torch.where(a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0)
             if not self.optim_update:
                 # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
                 node_for_angle_info = (
@@ -984,24 +938,18 @@ def forward(

                 # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
                 edge_for_angle_k = (
-                    torch.tile(
-                        edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)
-                    )
+                    torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1))
                     if not self.use_dynamic_sel
                     else torch.index_select(edge_ebd_for_angle, 0, eik2a_index)
                 )
                 # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
                 edge_for_angle_j = (
-                    torch.tile(
-                        edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)
-                    )
+                    torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1))
                     if not self.use_dynamic_sel
                     else torch.index_select(edge_ebd_for_angle, 0, eij2a_index)
                 )
                 # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
-                edge_for_angle_info = torch.cat(
-                    [edge_for_angle_k, edge_for_angle_j], dim=-1
-                )
+                edge_for_angle_info = torch.cat([edge_for_angle_k, edge_for_angle_j], dim=-1)
                 angle_info_list = [angle_ebd]
                 angle_info_list.append(node_for_angle_info)
                 angle_info_list.append(edge_for_angle_info)
@@ -1039,15 +987,9 @@ def forward(

             if not self.use_dynamic_sel:
                 # nb x nloc x a_nnei x a_nnei x e_dim
-                weighted_edge_angle_update = (
-                    a_sw.unsqueeze(-1).unsqueeze(-1)
-                    * a_sw.unsqueeze(-2).unsqueeze(-1)
-                    * edge_angle_update
-                )
+                weighted_edge_angle_update = a_sw.unsqueeze(-1).unsqueeze(-1) * a_sw.unsqueeze(-2).unsqueeze(-1) * edge_angle_update
                 # nb x nloc x a_nnei x e_dim
-                reduced_edge_angle_update = torch.sum(
-                    weighted_edge_angle_update, dim=-2
-                ) / (self.a_sel**0.5)
+                reduced_edge_angle_update = torch.sum(weighted_edge_angle_update, dim=-2) / (self.a_sel**0.5)
                 # nb x nloc x nnei x e_dim
                 padding_edge_angle_update = torch.concat(
                     [
@@ -1075,9 +1017,7 @@ def forward(
             # will be deprecated in the future
             # not support dynamic index, will pass anyway
             if self.use_dynamic_sel:
-                raise NotImplementedError(
-                    "smooth_edge_update must be True when use_dynamic_sel is True!"
-                )
+                raise NotImplementedError("smooth_edge_update must be True when use_dynamic_sel is True!")
             full_mask = torch.concat(
                 [
                     a_nlist_mask,
@@ -1089,12 +1029,8 @@ def forward(
                 ],
                 dim=-1,
             )
-            padding_edge_angle_update = torch.where(
-                full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
-            )
-            e_update_list.append(
-                self.act(self.edge_angle_linear2(padding_edge_angle_update))
-            )
+            padding_edge_angle_update = torch.where(full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd)
+            e_update_list.append(self.act(self.edge_angle_linear2(padding_edge_angle_update)))
         # update edge_ebd
         e_updated = self.list_update(e_update_list, "edge")

@@ -1152,9 +1088,7 @@ def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor:
         return uu

     @torch.jit.export
-    def list_update_res_residual(
-        self, update_list: list[torch.Tensor], update_name: str = "node"
-    ) -> torch.Tensor:
+    def list_update_res_residual(self, update_list: list[torch.Tensor], update_name: str = "node") -> torch.Tensor:
         nitem = len(update_list)
         uu = update_list[0]
         # make jit happy
@@ -1172,9 +1106,7 @@ def list_update_res_residual(
         return uu

     @torch.jit.export
-    def list_update(
-        self, update_list: list[torch.Tensor], update_name: str = "node"
-    ) -> torch.Tensor:
+    def list_update(self, update_list: list[torch.Tensor], update_name: str = "node") -> torch.Tensor:
         if self.update_style == "res_avg":
             return self.list_update_res_avg(update_list)
         elif self.update_style == "res_incr":

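For reference, the res_residual style dispatched here accumulates each extra update scaled by a learned coefficient. A minimal sketch of the pattern (coefficients hypothetical, not the class's actual buffers):

import torch

update_list = [torch.zeros(2, 4), torch.ones(2, 4), 2 * torch.ones(2, 4)]
residual = [torch.tensor(0.5), torch.tensor(0.1)]  # one coefficient per extra update
uu = update_list[0]
for idx, update in enumerate(update_list[1:]):
    uu = uu + residual[idx] * update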
deepmd/pt/model/descriptor/repflows.py

Lines changed: 0 additions & 11 deletions
@@ -472,17 +472,6 @@ def forward(
         # beyond the cutoff sw should be 0.0
         sw = sw.masked_fill(~nlist_mask, 0.0)

-        # [nframes, nloc, tebd_dim]
-        assert extended_atype_embd is not None
-        atype_embd = extended_atype_embd[:, :nloc, :]
-        assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
-        node_ebd = self.act(atype_embd)
-        n_dim = node_ebd.shape[-1]
-        # nb x nloc x nnei x 1, nb x nloc x nnei x 3
-        edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
-        # nb x nloc x nnei x e_dim
-        edge_ebd = self.act(self.edge_embd(edge_input))
-
         # get angle nlist (maybe smaller)
         a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
             :, :, : self.a_sel
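The deleted block built the initial node and edge embeddings inline; its core step is a torch.split of the distance matrix into a scalar edge input and a 3-vector part. A toy version of that split (shapes assumed):

import torch

nb, nloc, nnei = 1, 2, 3
dmatrix = torch.randn(nb, nloc, nnei, 4)
# nb x nloc x nnei x 1 and nb x nloc x nnei x 3
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
assert edge_input.shape[-1] == 1 and h2.shape[-1] == 3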

deepmd/pt/model/network/utils.py

Lines changed: 3 additions & 0 deletions
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)

 import torch

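The Optional import matches an aggregate helper whose owner count can be passed explicitly (as the num_owner=num_owner call above suggests). A plausible scatter-sum sketch, written as an assumption rather than the file's actual body:

from typing import Optional

import torch

def aggregate(data: torch.Tensor, owners: torch.Tensor, average: bool = True, num_owner: Optional[int] = None) -> torch.Tensor:
    # sum rows of data into bins given by owners; assumed semantics
    if num_owner is None:
        num_owner = int(owners.max().item()) + 1
    output = torch.zeros(num_owner, data.shape[1], dtype=data.dtype, device=data.device)
    output = output.index_add_(0, owners, data)
    if average:
        counts = torch.bincount(owners, minlength=num_owner).clamp(min=1)
        output = output / counts.unsqueeze(-1)
    return output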
