Skip to content

Commit d1d7445

Browse files
committed
lint and cleanup
1 parent e9d4f9e commit d1d7445

File tree

4 files changed

+46
-66
lines changed

4 files changed

+46
-66
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,5 @@
11
from __future__ import annotations
22

3-
"""
4-
global ein notation:
5-
6-
b - batch
7-
ba - batch with augmentation
8-
h - heads
9-
n - residue sequence length
10-
i - residue sequence length (source)
11-
j - residue sequence length (target)
12-
m - atom sequence length
13-
nw - windowed sequence length
14-
d - feature dimension
15-
ds - feature dimension (single)
16-
dp - feature dimension (pairwise)
17-
dap - feature dimension (atompair)
18-
dapi - feature dimension (atompair input)
19-
da - feature dimension (atom)
20-
dai - feature dimension (atom input)
21-
t - templates
22-
s - msa
23-
r - registers
24-
"""
25-
26-
"""
27-
additional_residue_feats: [*, 10]:
28-
29-
0: residue_index
30-
1: token_index
31-
2: asym_id
32-
3: entity_id
33-
4: sym_id
34-
5: restype (must be one hot encoded to 32)
35-
6: is_protein
36-
7: is_rna
37-
8: is_dna
38-
9: is_ligand
39-
"""
40-
413
from math import pi, sqrt
424
from pathlib import Path
435
from functools import partial, wraps
@@ -85,6 +47,44 @@
8547

8648
from importlib.metadata import version
8749

50+
"""
51+
global ein notation:
52+
53+
b - batch
54+
ba - batch with augmentation
55+
h - heads
56+
n - residue sequence length
57+
i - residue sequence length (source)
58+
j - residue sequence length (target)
59+
m - atom sequence length
60+
nw - windowed sequence length
61+
d - feature dimension
62+
ds - feature dimension (single)
63+
dp - feature dimension (pairwise)
64+
dap - feature dimension (atompair)
65+
dapi - feature dimension (atompair input)
66+
da - feature dimension (atom)
67+
dai - feature dimension (atom input)
68+
t - templates
69+
s - msa
70+
r - registers
71+
"""
72+
73+
"""
74+
additional_residue_feats: [*, 10]:
75+
76+
0: residue_index
77+
1: token_index
78+
2: asym_id
79+
3: entity_id
80+
4: sym_id
81+
5: restype (must be one hot encoded to 32)
82+
6: is_protein
83+
7: is_rna
84+
8: is_dna
85+
9: is_ligand
86+
"""
87+
8888
# constants
8989

9090
ADDITIONAL_RESIDUE_FEATS = 10
@@ -241,24 +241,6 @@ def repeat_consecutive_with_lens(
241241

242242
return output
243243

244-
def repeat_pairwise_consecutive_with_lens(
245-
feats: Float['b n n dp'],
246-
lens: Int['b n']
247-
) -> Float['b m m dp']:
248-
249-
repeated_lens = repeat(lens, 'b ... -> (b repeat) ...', repeat = feats.shape[1])
250-
feats, ps = pack_one(feats, '* n dp')
251-
feats = repeat_consecutive_with_lens(feats, repeated_lens)
252-
feats = unpack_one(feats, ps, '* n dp')
253-
254-
feats = rearrange(feats, 'b i j dp -> b j i dp')
255-
repeated_lens = repeat(lens, 'b ... -> (b repeat) ...', repeat = feats.shape[1])
256-
feats, ps = pack_one(feats, '* n dp')
257-
feats = repeat_consecutive_with_lens(feats, repeated_lens)
258-
feats = unpack_one(feats, ps, '* n dp')
259-
feats = rearrange(feats, 'b j i dp -> b i j dp')
260-
return feats
261-
262244
# linear and outer sum
263245
# for single repr -> pairwise pattern throughout this architecture
264246

@@ -2105,8 +2087,6 @@ def forward(
21052087
align_weights = atom_pos_ground_truth.new_ones(atom_pos_ground_truth.shape[:2])
21062088

21072089
if exists(additional_residue_feats):
2108-
w = self.net.atoms_per_window
2109-
21102090
is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1)
21112091

21122092
is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, residue_atom_lens) for t in is_nucleotide_or_ligand_fields)

alphafold3_pytorch/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import NamedTuple, Tuple
33

44
import torch
5-
from torch import nn
5+
from torch import nn, Tensor
66
import torch.nn.functional as F
77
from torch.nn import Module
88

alphafold3_pytorch/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ class Alphafold3Input(TypedDict):
2929
additional_residue_feats: Float['n 10']
3030
templates: Float['t n n dt']
3131
msa: Float['s n dm']
32-
template_mask: Bool['t'] | None
33-
msa_mask: Bool['s'] | None
32+
template_mask: Bool[' t'] | None
33+
msa_mask: Bool[' s'] | None
3434
atom_pos: Float['m 3'] | None
35-
residue_atom_indices: Int['n'] | None
35+
residue_atom_indices: Int[' n'] | None
3636
distance_labels: Int['n n'] | None
3737
pae_labels: Int['n n'] | None
38-
pde_labels: Int['n'] | None
39-
resolved_labels: Int['n'] | None
38+
pde_labels: Int[' n'] | None
39+
resolved_labels: Int[' n'] | None
4040

4141
# helpers
4242

@@ -406,4 +406,4 @@ def __call__(
406406

407407
self.log(**test_loss_breakdown)
408408

409-
print(f'training complete')
409+
print('training complete')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)