@@ -210,12 +210,8 @@ def __init__(
210210 self .n_a_compress_dim = n_dim
211211 else :
212212 # angle + a_dim/c + a_dim/2c * 2 * e_rate
213- self .angle_dim += (1 + self .a_compress_e_rate ) * (
214- self .a_dim // self .a_compress_rate
215- )
216- self .e_a_compress_dim = (
217- self .a_dim // (2 * self .a_compress_rate ) * self .a_compress_e_rate
218- )
213+ self .angle_dim += (1 + self .a_compress_e_rate ) * (self .a_dim // self .a_compress_rate )
214+ self .e_a_compress_dim = self .a_dim // (2 * self .a_compress_rate ) * self .a_compress_e_rate
219215 self .n_a_compress_dim = self .a_dim // self .a_compress_rate
220216 if not self .a_compress_use_split :
221217 self .a_compress_n_linear = MLPLayer (
@@ -331,10 +327,7 @@ def _cal_hg(
331327 # nb x nloc x nnei x e_dim
332328 edge_ebd = _apply_nlist_mask (edge_ebd , nlist_mask )
333329 edge_ebd = _apply_switch (edge_ebd , sw )
334- invnnei = torch .rsqrt (
335- float (nnei )
336- * torch .ones ((nb , nloc , 1 , 1 ), dtype = edge_ebd .dtype , device = edge_ebd .device )
337- )
330+ invnnei = torch .rsqrt (float (nnei ) * torch .ones ((nb , nloc , 1 , 1 ), dtype = edge_ebd .dtype , device = edge_ebd .device ))
338331 # nb x nloc x 3 x e_dim
339332 h2g2 = torch .matmul (torch .transpose (h2 , - 1 , - 2 ), edge_ebd ) * invnnei
340333 return h2g2
@@ -382,16 +375,9 @@ def _cal_hg_dynamic(
382375 # n_edge x e_dim
383376 flat_edge_ebd = flat_edge_ebd * flat_sw .unsqueeze (- 1 )
384377 # n_edge x 3 x e_dim
385- flat_h2g2 = (flat_h2 .unsqueeze (- 1 ) * flat_edge_ebd .unsqueeze (- 2 )).reshape (
386- - 1 , 3 * e_dim
387- )
378+ flat_h2g2 = (flat_h2 .unsqueeze (- 1 ) * flat_edge_ebd .unsqueeze (- 2 )).reshape (- 1 , 3 * e_dim )
388379 # nf x nloc x 3 x e_dim
389- h2g2 = (
390- aggregate (flat_h2g2 , owner , average = False , num_owner = num_owner ).reshape (
391- nb , nloc , 3 , e_dim
392- )
393- * scale_factor
394- )
380+ h2g2 = aggregate (flat_h2g2 , owner , average = False , num_owner = num_owner ).reshape (nb , nloc , 3 , e_dim ) * scale_factor
395381 return h2g2
396382
397383 @staticmethod
@@ -544,9 +530,7 @@ def optim_angle_update(
544530 node_dim = node_ebd .shape [- 1 ]
545531 edge_dim = edge_ebd .shape [- 1 ]
546532 # angle_dim, node_dim, edge_dim, edge_dim
547- sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (
548- matrix , [angle_dim , node_dim , edge_dim , edge_dim ]
549- )
533+ sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (matrix , [angle_dim , node_dim , edge_dim , edge_dim ])
550534
551535 # nf * nloc * a_sel * a_sel * angle_dim
552536 sub_angle_update = torch .matmul (angle_ebd , sub_angle )
@@ -557,11 +541,7 @@ def optim_angle_update(
557541 sub_edge_update_ij = torch .matmul (edge_ebd , sub_edge_ij )
558542
559543 result_update = (
560- bias
561- + sub_node_update .unsqueeze (2 ).unsqueeze (3 )
562- + sub_edge_update_ik .unsqueeze (2 )
563- + sub_edge_update_ij .unsqueeze (3 )
564- + sub_angle_update
544+ bias + sub_node_update .unsqueeze (2 ).unsqueeze (3 ) + sub_edge_update_ik .unsqueeze (2 ) + sub_edge_update_ij .unsqueeze (3 ) + sub_angle_update
565545 )
566546 return result_update
567547
@@ -585,19 +565,15 @@ def optim_angle_update_dynamic(
585565 edge_dim = flat_edge_ebd .shape [- 1 ]
586566 angle_dim = flat_angle_ebd .shape [- 1 ]
587567 # angle_dim, node_dim, edge_dim, edge_dim
588- sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (
589- matrix , [angle_dim , node_dim , edge_dim , edge_dim ]
590- )
568+ sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (matrix , [angle_dim , node_dim , edge_dim , edge_dim ])
591569
592570 # n_angle * angle_dim
593571 sub_angle_update = torch .matmul (flat_angle_ebd , sub_angle )
594572
595573 # nf * nloc * angle_dim
596574 sub_node_update = torch .matmul (node_ebd , sub_node )
597575 # n_angle * angle_dim
598- sub_node_update = torch .index_select (
599- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2a_index
600- )
576+ sub_node_update = torch .index_select (sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2a_index )
601577
602578 # n_edge * angle_dim
603579 sub_edge_update_ik = torch .matmul (flat_edge_ebd , sub_edge_ik )
@@ -606,13 +582,7 @@ def optim_angle_update_dynamic(
606582 sub_edge_update_ik = torch .index_select (sub_edge_update_ik , 0 , eik2a_index )
607583 sub_edge_update_ij = torch .index_select (sub_edge_update_ij , 0 , eij2a_index )
608584
609- result_update = (
610- bias
611- + sub_node_update
612- + sub_edge_update_ik
613- + sub_edge_update_ij
614- + sub_angle_update
615- )
585+ result_update = bias + sub_node_update + sub_edge_update_ik + sub_edge_update_ij + sub_angle_update
616586 return result_update
617587
618588 def optim_edge_update (
@@ -645,9 +615,7 @@ def optim_edge_update(
645615 # nf * nloc * nnei * node/edge_dim
646616 sub_edge_update = torch .matmul (edge_ebd , edge )
647617
648- result_update = (
649- bias + sub_node_update .unsqueeze (2 ) + sub_edge_update + sub_node_ext_update
650- )
618+ result_update = bias + sub_node_update .unsqueeze (2 ) + sub_edge_update + sub_node_ext_update
651619 return result_update
652620
653621 def optim_edge_update_dynamic (
@@ -675,9 +643,7 @@ def optim_edge_update_dynamic(
675643 # nf * nloc * node/edge_dim
676644 sub_node_update = torch .matmul (node_ebd , node )
677645 # n_edge * node/edge_dim
678- sub_node_update = torch .index_select (
679- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2e_index
680- )
646+ sub_node_update = torch .index_select (sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2e_index )
681647
682648 # nf * nall * node/edge_dim
683649 sub_node_ext_update = torch .matmul (node_ebd_ext , node_ext )
@@ -757,14 +723,12 @@ def forward(
757723 nb , nloc , nnei = nlist .shape
758724 nall = node_ebd_ext .shape [1 ]
759725 node_ebd = node_ebd_ext [:, :nloc , :]
726+ n_edge = int (nlist_mask .sum ().item ()) if self .use_dynamic_sel else 0
760727 assert (nb , nloc ) == node_ebd .shape [:2 ]
761728 if not self .use_dynamic_sel :
762729 assert (nb , nloc , nnei , 3 ) == h2 .shape
763- n_edge = None
764730 else :
765- # n_edge = int(nlist_mask.sum().item())
766- # assert (n_edge, 3) == h2.shape
767- n_edge = h2 .shape [0 ]
731+ assert (n_edge , 3 ) == h2 .shape
768732 del a_nlist # may be used in the future
769733
770734 n2e_index , n_ext2e_index = edge_index [0 ], edge_index [1 ]
@@ -778,9 +742,7 @@ def forward(
778742 nei_node_ebd = (
779743 _make_nei_g1 (node_ebd_ext , nlist )
780744 if not self .use_dynamic_sel
781- else torch .index_select (
782- node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index
783- )
745+ else torch .index_select (node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index )
784746 )
785747
786748 n_update_list : list [torch .Tensor ] = [node_ebd ]
@@ -853,9 +815,7 @@ def forward(
853815 # n_edge x (n_dim * 2 + e_dim)
854816 edge_info = torch .cat (
855817 [
856- torch .index_select (
857- node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index
858- ),
818+ torch .index_select (node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index ),
859819 nei_node_ebd ,
860820 edge_ebd ,
861821 ],
@@ -868,9 +828,7 @@ def forward(
868828 # nb x nloc x nnei x (h * n_dim)
869829 if not self .optim_update :
870830 assert edge_info is not None
871- node_edge_update = self .act (
872- self .node_edge_linear (edge_info )
873- ) * sw .unsqueeze (- 1 )
831+ node_edge_update = self .act (self .node_edge_linear (edge_info )) * sw .unsqueeze (- 1 )
874832 else :
875833 node_edge_update = self .act (
876834 self .optim_edge_update (
@@ -906,9 +864,7 @@ def forward(
906864
907865 if self .n_multi_edge_message > 1 :
908866 # nb x nloc x h x n_dim
909- node_edge_update_mul_head = node_edge_update .view (
910- nb , nloc , self .n_multi_edge_message , self .n_dim
911- )
867+ node_edge_update_mul_head = node_edge_update .view (nb , nloc , self .n_multi_edge_message , self .n_dim )
912868 for head_index in range (self .n_multi_edge_message ):
913869 n_update_list .append (node_edge_update_mul_head [..., head_index , :])
914870 else :
@@ -964,9 +920,7 @@ def forward(
964920 # nb x nloc x a_nnei x e_dim
965921 edge_ebd_for_angle = edge_ebd_for_angle [..., : self .a_sel , :]
966922 # nb x nloc x a_nnei x e_dim
967- edge_ebd_for_angle = torch .where (
968- a_nlist_mask .unsqueeze (- 1 ), edge_ebd_for_angle , 0.0
969- )
923+ edge_ebd_for_angle = torch .where (a_nlist_mask .unsqueeze (- 1 ), edge_ebd_for_angle , 0.0 )
970924 if not self .optim_update :
971925 # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
972926 node_for_angle_info = (
@@ -984,24 +938,18 @@ def forward(
984938
985939 # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
986940 edge_for_angle_k = (
987- torch .tile (
988- edge_ebd_for_angle .unsqueeze (2 ), (1 , 1 , self .a_sel , 1 , 1 )
989- )
941+ torch .tile (edge_ebd_for_angle .unsqueeze (2 ), (1 , 1 , self .a_sel , 1 , 1 ))
990942 if not self .use_dynamic_sel
991943 else torch .index_select (edge_ebd_for_angle , 0 , eik2a_index )
992944 )
993945 # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
994946 edge_for_angle_j = (
995- torch .tile (
996- edge_ebd_for_angle .unsqueeze (3 ), (1 , 1 , 1 , self .a_sel , 1 )
997- )
947+ torch .tile (edge_ebd_for_angle .unsqueeze (3 ), (1 , 1 , 1 , self .a_sel , 1 ))
998948 if not self .use_dynamic_sel
999949 else torch .index_select (edge_ebd_for_angle , 0 , eij2a_index )
1000950 )
1001951 # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
1002- edge_for_angle_info = torch .cat (
1003- [edge_for_angle_k , edge_for_angle_j ], dim = - 1
1004- )
952+ edge_for_angle_info = torch .cat ([edge_for_angle_k , edge_for_angle_j ], dim = - 1 )
1005953 angle_info_list = [angle_ebd ]
1006954 angle_info_list .append (node_for_angle_info )
1007955 angle_info_list .append (edge_for_angle_info )
@@ -1039,15 +987,9 @@ def forward(
1039987
1040988 if not self .use_dynamic_sel :
1041989 # nb x nloc x a_nnei x a_nnei x e_dim
1042- weighted_edge_angle_update = (
1043- a_sw .unsqueeze (- 1 ).unsqueeze (- 1 )
1044- * a_sw .unsqueeze (- 2 ).unsqueeze (- 1 )
1045- * edge_angle_update
1046- )
990+ weighted_edge_angle_update = a_sw .unsqueeze (- 1 ).unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 ).unsqueeze (- 1 ) * edge_angle_update
1047991 # nb x nloc x a_nnei x e_dim
1048- reduced_edge_angle_update = torch .sum (
1049- weighted_edge_angle_update , dim = - 2
1050- ) / (self .a_sel ** 0.5 )
992+ reduced_edge_angle_update = torch .sum (weighted_edge_angle_update , dim = - 2 ) / (self .a_sel ** 0.5 )
1051993 # nb x nloc x nnei x e_dim
1052994 padding_edge_angle_update = torch .concat (
1053995 [
@@ -1075,9 +1017,7 @@ def forward(
10751017 # will be deprecated in the future
10761018 # not support dynamic index, will pass anyway
10771019 if self .use_dynamic_sel :
1078- raise NotImplementedError (
1079- "smooth_edge_update must be True when use_dynamic_sel is True!"
1080- )
1020+ raise NotImplementedError ("smooth_edge_update must be True when use_dynamic_sel is True!" )
10811021 full_mask = torch .concat (
10821022 [
10831023 a_nlist_mask ,
@@ -1089,12 +1029,8 @@ def forward(
10891029 ],
10901030 dim = - 1 ,
10911031 )
1092- padding_edge_angle_update = torch .where (
1093- full_mask .unsqueeze (- 1 ), padding_edge_angle_update , edge_ebd
1094- )
1095- e_update_list .append (
1096- self .act (self .edge_angle_linear2 (padding_edge_angle_update ))
1097- )
1032+ padding_edge_angle_update = torch .where (full_mask .unsqueeze (- 1 ), padding_edge_angle_update , edge_ebd )
1033+ e_update_list .append (self .act (self .edge_angle_linear2 (padding_edge_angle_update )))
10981034 # update edge_ebd
10991035 e_updated = self .list_update (e_update_list , "edge" )
11001036
@@ -1152,9 +1088,7 @@ def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor:
11521088 return uu
11531089
11541090 @torch .jit .export
1155- def list_update_res_residual (
1156- self , update_list : list [torch .Tensor ], update_name : str = "node"
1157- ) -> torch .Tensor :
1091+ def list_update_res_residual (self , update_list : list [torch .Tensor ], update_name : str = "node" ) -> torch .Tensor :
11581092 nitem = len (update_list )
11591093 uu = update_list [0 ]
11601094 # make jit happy
@@ -1172,9 +1106,7 @@ def list_update_res_residual(
11721106 return uu
11731107
11741108 @torch .jit .export
1175- def list_update (
1176- self , update_list : list [torch .Tensor ], update_name : str = "node"
1177- ) -> torch .Tensor :
1109+ def list_update (self , update_list : list [torch .Tensor ], update_name : str = "node" ) -> torch .Tensor :
11781110 if self .update_style == "res_avg" :
11791111 return self .list_update_res_avg (update_list )
11801112 elif self .update_style == "res_incr" :
0 commit comments