Skip to content

Commit 81d904f

Browse files
authored
Add MultiChainPermutationAlignment module (#191)
* Update alphafold3.py * Update model_utils.py * Update alphafold3.py * Update __init__.py * Update test_af3.py * Update README.md * Update test_af3.py * Update test_af3.py * Update alphafold3.py * Update alphafold3.py * Update alphafold3.py
1 parent 57b1322 commit 81d904f

File tree

5 files changed

+1488
-285
lines changed

5 files changed

+1488
-285
lines changed

README.md

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ $ pip install alphafold3-pytorch
6565
```python
6666
import torch
6767
from alphafold3_pytorch import Alphafold3
68+
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
6869

6970
alphafold3 = Alphafold3(
7071
dim_atom_inputs = 77,
@@ -74,8 +75,12 @@ alphafold3 = Alphafold3(
7475
# mock inputs
7576

7677
seq_len = 16
77-
molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
78-
atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
78+
79+
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
80+
molecule_atom_lens = torch.full((2, seq_len), 2).long()
81+
82+
atom_seq_len = molecule_atom_lens.sum(dim=-1).amax()
83+
atom_offsets = exclusive_cumsum(molecule_atom_lens)
7984

8085
atom_inputs = torch.randn(2, atom_seq_len, 77)
8186
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
@@ -98,12 +103,16 @@ additional_msa_feats = torch.randn(2, 7, seq_len, 2)
98103

99104
atom_pos = torch.randn(2, atom_seq_len, 3)
100105

101-
molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example
102-
molecule_atom_indices += (molecule_atom_lens.cumsum(dim = -1) - molecule_atom_lens)
106+
distogram_atom_indices = molecule_atom_lens - 1
103107

104108
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
105109
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
106110

111+
# offset indices correctly
112+
113+
distogram_atom_indices += atom_offsets
114+
molecule_atom_indices += atom_offsets
115+
107116
# train
108117

109118
loss = alphafold3(
@@ -122,6 +131,7 @@ loss = alphafold3(
122131
templates = template_feats,
123132
template_mask = template_mask,
124133
atom_pos = atom_pos,
134+
distogram_atom_indices = distogram_atom_indices,
125135
molecule_atom_indices = molecule_atom_indices,
126136
distance_labels = distance_labels,
127137
resolved_labels = resolved_labels

alphafold3_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
RelativePositionEncoding,
99
SmoothLDDTLoss,
1010
WeightedRigidAlign,
11+
MultiChainPermutationAlignment,
1112
ExpressCoordinatesInFrame,
1213
RigidFrom3Points,
1314
ComputeAlignmentError,
@@ -76,6 +77,7 @@
7677
RelativePositionEncoding,
7778
SmoothLDDTLoss,
7879
WeightedRigidAlign,
80+
MultiChainPermutationAlignment,
7981
ExpressCoordinatesInFrame,
8082
ComputeAlignmentError,
8183
CentreRandomAugmentation,

0 commit comments

Comments
 (0)