@@ -2102,13 +2102,12 @@ def _random_translation_vector(self, device: torch.device) -> Float['3']:
21022102
21032103# input embedder
21042104
2105- EmbeddedInputs = namedtuple ('EmbeddedInputs' , [
2106- 'single_inputs' ,
2107- 'single_init' ,
2108- 'pairwise_init' ,
2109- 'atom_feats' ,
2110- 'atompair_feats'
2111- ])
2105+ class EmbeddedInputs (NamedTuple ):
2106+ single_inputs : Float ['b n ds' ]
2107+ single_init : Float ['b n ds' ]
2108+ pairwise_init : Float ['b n n dp' ]
2109+ atom_feats : Float ['b m da' ]
2110+ atompair_feats : Float ['b m m dap' ]
21122111
21132112class InputFeatureEmbedder (Module ):
21142113 """ Algorithm 2 """
@@ -2126,7 +2125,7 @@ def __init__(
21262125 dim_pairwise = 128 ,
21272126 atom_transformer_blocks = 3 ,
21282127 atom_transformer_heads = 4 ,
2129- atom_transformer_kwargs : dict = dict ()
2128+ atom_transformer_kwargs : dict = dict (),
21302129 ):
21312130 super ().__init__ ()
21322131 self .atoms_per_window = atoms_per_window
@@ -2178,13 +2177,7 @@ def forward(
21782177 atom_mask : Bool ['b m' ],
21792178 atompair_feats : Float ['b m m dap' ],
21802179 additional_residue_feats : Float ['b n rf' ],
2181- ) -> EmbeddedInputs [
2182- Float ['b n ds' ],
2183- Float ['b n ds' ],
2184- Float ['b n n dp' ],
2185- Float ['b m da' ],
2186- Float ['b m m dap' ]
2187- ]:
2180+ ) -> EmbeddedInputs :
21882181
21892182 assert additional_residue_feats .shape [- 1 ] == self .dim_additional_residue_feats
21902183
@@ -2244,12 +2237,11 @@ def forward(
22442237
22452238# confidence head
22462239
2247- ConfidenceHeadLogits = namedtuple ('ConfidenceHeadLogits' , [
2248- 'pae' ,
2249- 'pde' ,
2250- 'plddt' ,
2251- 'resolved'
2252- ])
2240+ class ConfidenceHeadLogits (NamedTuple ):
2241+ pae : Float ['b pae n n' ] | None
2242+ pde : Float ['b pde n n' ]
2243+ plddt : Float ['b plddt n' ]
2244+ resolved : Float ['b 2 n' ]
22532245
22542246class ConfidenceHead (Module ):
22552247 """ Algorithm 31 """
@@ -2320,12 +2312,7 @@ def forward(
23202312 mask : Bool ['b n' ] | None = None ,
23212313 return_pae_logits = True
23222314
2323- ) -> ConfidenceHeadLogits [
2324- Float ['b pae n n' ] | None ,
2325- Float ['b pde n n' ],
2326- Float ['b plddt n' ],
2327- Float ['b resolved n' ]
2328- ]:
2315+ ) -> ConfidenceHeadLogits :
23292316
23302317 pairwise_repr = pairwise_repr + self .single_inputs_to_pairwise (single_inputs_repr )
23312318
0 commit comments