Skip to content

Commit 49bb8a5

Browse files
committed
fix readme examples and add to main test
1 parent 3b445eb commit 49bb8a5

File tree

5 files changed

+160
-5
lines changed

5 files changed

+160
-5
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
fail-fast: false
1515
matrix:
16-
group: [0, 1, 2, 3, 4]
16+
group: [0, 1, 2, 3, 4, 5, 6, 7]
1717

1818
steps:
1919
- uses: actions/checkout@v4
@@ -28,4 +28,4 @@ jobs:
2828
python -m pip install --default-timeout=100 -e .[test]
2929
- name: Test with pytest
3030
run: |
31-
python -m pytest --num-shards 5 --shard-id ${{ matrix.group }} tests/
31+
python -m pytest --num-shards 8 --shard-id ${{ matrix.group }} tests/

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ additional_msa_feats = torch.randn(2, 7, seq_len, 2)
9797
# required for training, but omitted on inference
9898

9999
atom_pos = torch.randn(2, atom_seq_len, 3)
100+
100101
molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example
102+
molecule_atom_indices += (molecule_atom_lens.cumsum(dim = -1) - molecule_atom_lens)
101103

102104
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
103105
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5692,7 +5692,12 @@ def forward(
56925692

56935693
has_labels = any([*map(exists, all_labels)])
56945694

5695-
can_return_loss = atom_pos_given or has_labels
5695+
can_return_loss = (
5696+
atom_pos_given or
5697+
exists(resolved_labels) or
5698+
exists(distance_labels) or
5699+
(atom_pos_given and exists(atom_indices_for_frame))
5700+
)
56965701

56975702
# default whether to return loss by whether labels or atom positions are given
56985703

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.2"
3+
version = "0.4.3"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
ConfidenceHeadLogits,
3737
ComputeModelSelectionScore,
3838
ComputeModelSelectionScore,
39-
collate_inputs_to_batched_atom_input
39+
collate_inputs_to_batched_atom_input,
40+
alphafold3_inputs_to_batched_atom_input,
4041
)
4142

4243
from alphafold3_pytorch.mocks import MockAtomDataset
@@ -61,6 +62,7 @@
6162
molecule_to_atom_input,
6263
pdb_input_to_molecule_input,
6364
PDBInput,
65+
Alphafold3Input,
6466
PDBDataset,
6567
default_extract_atom_feats_fn,
6668
default_extract_atompair_feats_fn,
@@ -1226,3 +1228,149 @@ def test_unresolved_protein_rasa():
12261228
molecule_atom_lens=batched_atom_input_dict['molecule_atom_lens'],
12271229
atom_pos=batched_atom_input_dict['atom_pos'],
12281230
atom_mask=~batched_atom_input_dict['missing_atom_mask'])
1231+
1232+
def test_readme1():
1233+
alphafold3 = Alphafold3(
1234+
dim_atom_inputs = 77,
1235+
dim_template_feats = 44
1236+
)
1237+
1238+
# mock inputs
1239+
1240+
seq_len = 16
1241+
molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
1242+
atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
1243+
1244+
atom_inputs = torch.randn(2, atom_seq_len, 77)
1245+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
1246+
1247+
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
1248+
additional_token_feats = torch.randn(2, seq_len, 33)
1249+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
1250+
is_molecule_mod = torch.randint(0, 2, (2, seq_len, 4)).bool()
1251+
molecule_ids = torch.randint(0, 32, (2, seq_len))
1252+
1253+
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
1254+
template_mask = torch.ones((2, 2)).bool()
1255+
1256+
msa = torch.randn(2, 7, seq_len, 32)
1257+
msa_mask = torch.ones((2, 7)).bool()
1258+
1259+
additional_msa_feats = torch.randn(2, 7, seq_len, 2)
1260+
1261+
# required for training, but omitted on inference
1262+
1263+
atom_pos = torch.randn(2, atom_seq_len, 3)
1264+
1265+
molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example
1266+
molecule_atom_indices += (molecule_atom_lens.cumsum(dim = -1) - molecule_atom_lens)
1267+
1268+
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
1269+
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
1270+
1271+
# train
1272+
1273+
loss = alphafold3(
1274+
num_recycling_steps = 2,
1275+
atom_inputs = atom_inputs,
1276+
atompair_inputs = atompair_inputs,
1277+
molecule_ids = molecule_ids,
1278+
molecule_atom_lens = molecule_atom_lens,
1279+
additional_molecule_feats = additional_molecule_feats,
1280+
additional_msa_feats = additional_msa_feats,
1281+
additional_token_feats = additional_token_feats,
1282+
is_molecule_types = is_molecule_types,
1283+
is_molecule_mod = is_molecule_mod,
1284+
msa = msa,
1285+
msa_mask = msa_mask,
1286+
templates = template_feats,
1287+
template_mask = template_mask,
1288+
atom_pos = atom_pos,
1289+
molecule_atom_indices = molecule_atom_indices,
1290+
distance_labels = distance_labels,
1291+
resolved_labels = resolved_labels
1292+
)
1293+
1294+
loss.backward()
1295+
1296+
# after much training ...
1297+
1298+
sampled_atom_pos = alphafold3(
1299+
num_recycling_steps = 4,
1300+
num_sample_steps = 16,
1301+
atom_inputs = atom_inputs,
1302+
atompair_inputs = atompair_inputs,
1303+
molecule_ids = molecule_ids,
1304+
molecule_atom_lens = molecule_atom_lens,
1305+
additional_molecule_feats = additional_molecule_feats,
1306+
additional_msa_feats = additional_msa_feats,
1307+
additional_token_feats = additional_token_feats,
1308+
is_molecule_types = is_molecule_types,
1309+
is_molecule_mod = is_molecule_mod,
1310+
msa = msa,
1311+
msa_mask = msa_mask,
1312+
templates = template_feats,
1313+
template_mask = template_mask
1314+
)
1315+
1316+
sampled_atom_pos.shape # (2, <atom_seqlen>, 3)
1317+
assert sampled_atom_pos.ndim == 3
1318+
1319+
def test_readme2():
1320+
contrived_protein = 'AG'
1321+
1322+
mock_atompos = [
1323+
torch.randn(5, 3), # alanine has 5 non-hydrogen atoms
1324+
torch.randn(4, 3) # glycine has 4 non-hydrogen atoms
1325+
]
1326+
1327+
train_alphafold3_input = Alphafold3Input(
1328+
proteins = [contrived_protein],
1329+
atom_pos = mock_atompos
1330+
)
1331+
1332+
eval_alphafold3_input = Alphafold3Input(
1333+
proteins = [contrived_protein]
1334+
)
1335+
1336+
batched_atom_input = alphafold3_inputs_to_batched_atom_input(train_alphafold3_input, atoms_per_window = 27)
1337+
1338+
# training
1339+
1340+
alphafold3 = Alphafold3(
1341+
dim_atom_inputs = 3,
1342+
dim_atompair_inputs = 5,
1343+
atoms_per_window = 27,
1344+
dim_template_feats = 44,
1345+
num_dist_bins = 38,
1346+
num_molecule_mods = 0,
1347+
confidence_head_kwargs = dict(
1348+
pairformer_depth = 1
1349+
),
1350+
template_embedder_kwargs = dict(
1351+
pairformer_stack_depth = 1
1352+
),
1353+
msa_module_kwargs = dict(
1354+
depth = 1
1355+
),
1356+
pairformer_stack = dict(
1357+
depth = 2
1358+
),
1359+
diffusion_module_kwargs = dict(
1360+
atom_encoder_depth = 1,
1361+
token_transformer_depth = 1,
1362+
atom_decoder_depth = 1,
1363+
)
1364+
)
1365+
1366+
loss = alphafold3(**batched_atom_input.model_forward_dict())
1367+
loss.backward()
1368+
1369+
# sampling
1370+
1371+
batched_eval_atom_input = alphafold3_inputs_to_batched_atom_input(eval_alphafold3_input, atoms_per_window = 27)
1372+
1373+
alphafold3.eval()
1374+
sampled_atom_pos = alphafold3(**batched_eval_atom_input.model_forward_dict())
1375+
1376+
assert sampled_atom_pos.shape == (1, (5 + 4), 3)

0 commit comments

Comments
 (0)