Skip to content

Commit 83955fc

Browse files
committed
some cleanup
1 parent cc8ee0c commit 83955fc

File tree

1 file changed

+14
-27
lines changed

1 file changed

+14
-27
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,13 +2102,12 @@ def _random_translation_vector(self, device: torch.device) -> Float['3']:
21022102

21032103
# input embedder
21042104

2105-
EmbeddedInputs = namedtuple('EmbeddedInputs', [
2106-
'single_inputs',
2107-
'single_init',
2108-
'pairwise_init',
2109-
'atom_feats',
2110-
'atompair_feats'
2111-
])
2105+
class EmbeddedInputs(NamedTuple):
2106+
single_inputs: Float['b n ds']
2107+
single_init: Float['b n ds']
2108+
pairwise_init: Float['b n n dp']
2109+
atom_feats: Float['b m da']
2110+
atompair_feats: Float['b m m dap']
21122111

21132112
class InputFeatureEmbedder(Module):
21142113
""" Algorithm 2 """
@@ -2126,7 +2125,7 @@ def __init__(
21262125
dim_pairwise = 128,
21272126
atom_transformer_blocks = 3,
21282127
atom_transformer_heads = 4,
2129-
atom_transformer_kwargs: dict = dict()
2128+
atom_transformer_kwargs: dict = dict(),
21302129
):
21312130
super().__init__()
21322131
self.atoms_per_window = atoms_per_window
@@ -2178,13 +2177,7 @@ def forward(
21782177
atom_mask: Bool['b m'],
21792178
atompair_feats: Float['b m m dap'],
21802179
additional_residue_feats: Float['b n rf'],
2181-
) -> EmbeddedInputs[
2182-
Float['b n ds'],
2183-
Float['b n ds'],
2184-
Float['b n n dp'],
2185-
Float['b m da'],
2186-
Float['b m m dap']
2187-
]:
2180+
) -> EmbeddedInputs:
21882181

21892182
assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
21902183

@@ -2244,12 +2237,11 @@ def forward(
22442237

22452238
# confidence head
22462239

2247-
ConfidenceHeadLogits = namedtuple('ConfidenceHeadLogits', [
2248-
'pae',
2249-
'pde',
2250-
'plddt',
2251-
'resolved'
2252-
])
2240+
class ConfidenceHeadLogits(NamedTuple):
2241+
pae: Float['b pae n n'] | None
2242+
pde: Float['b pde n n']
2243+
plddt: Float['b plddt n']
2244+
resolved: Float['b 2 n']
22532245

22542246
class ConfidenceHead(Module):
22552247
""" Algorithm 31 """
@@ -2320,12 +2312,7 @@ def forward(
23202312
mask: Bool['b n'] | None = None,
23212313
return_pae_logits = True
23222314

2323-
) -> ConfidenceHeadLogits[
2324-
Float['b pae n n'] | None,
2325-
Float['b pde n n'],
2326-
Float['b plddt n'],
2327-
Float['b resolved n']
2328-
]:
2315+
) -> ConfidenceHeadLogits:
23292316

23302317
pairwise_repr = pairwise_repr + self.single_inputs_to_pairwise(single_inputs_repr)
23312318

0 commit comments

Comments
 (0)