Skip to content

Commit c9eafae

Browse files
authored
quick fix (#133)
* hack to fix pytest-split-tests * try again * missing comma * some lint * some cleanup * skip rasa test if dssp not installed * another fix * attempt to install dssp in test env * sudo * skip rasa test for now
1 parent 68030fc commit c9eafae

File tree

6 files changed

+31
-33
lines changed

6 files changed

+31
-33
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ A fork with full Lightning + Hydra support is being maintained by <a href="https
4242

4343
- <a href="https://github.com/sj900">sj900</a> for integrating and testing the `WeightedPDBSampler` within the `PDBDataset`!
4444

45-
- <a href="https://github.com/xluo233">@xluo233</a> again for contributing the logic for computing the model selection score!
45+
- <a href="https://github.com/xluo233">@xluo233</a> again for contributing the logic for computing the model selection score as well as the unresolved rasa!
4646

4747
- <a href="https://github.com/wufandi">Fandi</a> for discovering a few inconsistencies in the elucidated atom diffusion module with the supplementary
4848

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,11 +1995,9 @@ def forward(
19951995
mask: Bool['b n'] | None = None,
19961996
windowed_mask: Bool['b nw w (w*2)'] | None = None
19971997
):
1998-
w = self.attn_window_size
1998+
w, serial = self.attn_window_size, self.serial
19991999
has_windows = exists(w)
20002000

2001-
serial = self.serial
2002-
20032001
# handle windowing
20042002

20052003
pairwise_is_windowed = pairwise_repr.ndim == 5
@@ -2022,9 +2020,9 @@ def forward(
20222020

20232021
# main transformer
20242022

2025-
if self.serial and should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
2023+
if serial and should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
20262024
to_layers_fn = self.to_checkpointed_serial_layers
2027-
elif self.serial:
2025+
elif serial:
20282026
to_layers_fn = self.to_serial_layers
20292027
else:
20302028
to_layers_fn = self.to_parallel_layers
@@ -4067,7 +4065,7 @@ def __init__(
40674065
other_cutoff: float = 15.0,
40684066
contact_mask_threshold: float = 8.0,
40694067
is_fine_tuning: bool = False,
4070-
weight_dict_config: dict = None
4068+
weight_dict_config: dict = None,
40714069
dssp_path: str = 'mkdssp',
40724070
):
40734071

@@ -4105,7 +4103,7 @@ def compute_gpde(
41054103
dist_probs = F.softmax(dist_logits, dim=-1)
41064104

41074105
# for distances greater than the last breaks
4108-
dist_breaks = F.pdb(dist_breaks, (0, 1), value = 1e6)
4106+
dist_breaks = F.pad(dist_breaks, (0, 1), value = 1e6)
41094107
contact_mask = dist_breaks < self.contact_mask_threshold
41104108

41114109
contact_prob = einx.where(
@@ -4389,7 +4387,7 @@ def compute_unresolved_rasa(
43894387
molecule_atom_lens: Int['b n'],
43904388
atom_pos: Float['b m 3'],
43914389
atom_mask: Bool['b m'],
4392-
) -> Float['b']:
4390+
) -> Float[' b']:
43934391

43944392
unresolved_rasa = [self._compute_unresolved_rasa(*args) for args in
43954393
zip(unresolved_cid, unresolved_residue_mask, asym_id, molecule_ids, molecule_atom_lens, atom_pos, atom_mask)]

alphafold3_pytorch/inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def __getitem__(self, idx: int) -> AtomInput:
316316
@typecheck
317317
def atom_ref_pos_to_atompair_inputs(
318318
atom_ref_pos: Float['m 3'],
319-
atom_ref_space_uid: Int['m'] | None = None,
319+
atom_ref_space_uid: Int[' m'] | None = None,
320320
) -> Float['m m 5']:
321321

322322
# Algorithm 5 - lines 2-6
@@ -691,7 +691,7 @@ class MoleculeLengthMoleculeInput:
691691
src_tgt_atom_indices: Int['n 2']
692692
token_bonds: Bool['n n'] | None = None
693693
one_token_per_atom: List[bool] | None = None
694-
is_molecule_mod: Bool['n num_mods'] | Bool['n'] | None = None
694+
is_molecule_mod: Bool['n num_mods'] | Bool[' n'] | None = None
695695
molecule_atom_indices: List[int | None] | None = None
696696
distogram_atom_indices: List[int | None] | None = None
697697
missing_atom_indices: List[Int[' _'] | None] | None = None

alphafold3_pytorch/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import torch
3737
from torch import Tensor
38+
from torch.nn import Module
3839
from torch.optim import Adam, Optimizer
3940
from torch.nn.utils.rnn import pad_sequence
4041
from torch.utils.data import Sampler, Dataset, DataLoader as OrigDataLoader

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.91"
3+
version = "0.2.92"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import pytest
55
import random
66
import itertools
7+
import subprocess
78
from pathlib import Path
89

910
import torch
1011

12+
from collections import namedtuple
13+
1114
from alphafold3_pytorch import (
1215
SmoothLDDTLoss,
1316
WeightedRigidAlign,
@@ -30,6 +33,7 @@
3033
ConfidenceHeadLogits,
3134
ComputeModelSelectionScore,
3235
ComputeModelSelectionScore,
36+
collate_inputs_to_batched_atom_input
3337
)
3438

3539
from alphafold3_pytorch.configs import (
@@ -40,12 +44,20 @@
4044
from alphafold3_pytorch.alphafold3 import (
4145
mean_pool_with_lens,
4246
repeat_consecutive_with_lens,
43-
full_pairwise_repr_to_windowed
47+
full_pairwise_repr_to_windowed,
48+
get_cid_molecule_type,
4449
)
4550

4651
from alphafold3_pytorch.inputs import (
4752
IS_MOLECULE_TYPES,
48-
atom_ref_pos_to_atompair_inputs
53+
IS_PROTEIN,
54+
atom_ref_pos_to_atompair_inputs,
55+
molecule_to_atom_input,
56+
pdb_input_to_molecule_input,
57+
PDBInput,
58+
PDBDataset,
59+
default_extract_atom_feats_fn,
60+
default_extract_atompair_feats_fn
4961
)
5062

5163
def test_atom_ref_pos_to_atompair_inputs():
@@ -1090,27 +1102,14 @@ def test_model_selection_score():
10901102

10911103
def test_unresolved_protein_rasa():
10921104

1093-
from collections import namedtuple
1105+
# skip the test if dssp not installed
10941106

1095-
from alphafold3_pytorch.inputs import (
1096-
IS_MOLECULE_TYPES,
1097-
PDBInput,
1098-
default_extract_atom_feats_fn,
1099-
default_extract_atompair_feats_fn
1100-
1101-
)
1107+
try:
1108+
subprocess.check_output(["which", "mkdssp"])
1109+
except:
1110+
pytest.skip("mkdssp not found, test_unresolved_protein_rasa skipped")
11021111

1103-
from alphafold3_pytorch.inputs import (
1104-
PDBDataset,
1105-
molecule_to_atom_input,
1106-
pdb_input_to_molecule_input,
1107-
IS_PROTEIN,
1108-
)
1109-
1110-
from alphafold3_pytorch import collate_inputs_to_batched_atom_input
1111-
from alphafold3_pytorch.alphafold3 import (
1112-
get_cid_molecule_type,
1113-
)
1112+
# rest of the test
11141113

11151114
mmcif_filepath = os.path.join('data', 'test', '7a4d-assembly1.cif')
11161115
pdb_input = PDBInput(mmcif_filepath)

0 commit comments

Comments
 (0)