Skip to content

Commit 9907a3d

Browse files
authored
add end to end test for model selection (#177)
1 parent 9e6c7c9 commit 9907a3d

File tree

3 files changed

+166
-29
lines changed

3 files changed

+166
-29
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@
157157

158158
# constants
159159

160-
SCORED_SAMPLE = Tuple[int, Float["b m 3"], Float[" b"], Float[" b"]] # type: ignore
161-
162160
LinearNoBias = partial(Linear, bias = False)
163161

164162
# always use non reentrant checkpointing
@@ -3470,6 +3468,13 @@ class ConfidenceHeadLogits(NamedTuple):
34703468
plddt: Float['b plddt m']
34713469
resolved: Float['b 2 m']
34723470

3471+
class Alphafold3Logits(NamedTuple):
3472+
pae: Float['b pae n n'] | None
3473+
pde: Float['b pde n n']
3474+
plddt: Float['b plddt m']
3475+
resolved: Float['b 2 m']
3476+
distance: Float['b dist m m'] | Float['b dist n n'] | None
3477+
34733478
class ConfidenceHead(Module):
34743479
""" Algorithm 31 """
34753480

@@ -4281,6 +4286,13 @@ def _protein_structure_from_feature(
42814286

42824287
return builder.get_structure()
42834288

4289+
ScoredSample = Tuple[int, Float["b m 3"], Float[" b"], Float[" b"]] # type: ignore
4290+
4291+
class ScoreDetails(NamedTuple):
4292+
best_gpde_index: int
4293+
best_lddt_index: int
4294+
score: Float[' b']
4295+
scored_samples: ScoredSample
42844296

42854297
class ComputeModelSelectionScore(Module):
42864298
"""Compute model selection score."""
@@ -4787,15 +4799,20 @@ def compute_unresolved_rasa(
47874799
def compute_model_selection_score(
47884800
self,
47894801
batch: BatchedAtomInput,
4790-
samples: List[Tuple[Float["b m 3"], Float["b pde n n"], Float["b dist n n"]]],
4802+
samples: List[Tuple[
4803+
Float["b m 3"],
4804+
Float["b pde n n"],
4805+
Float["b dist n n"]
4806+
]],
47914807
is_fine_tuning: bool = None,
4792-
return_top_model: bool = False,
4808+
return_details: bool = False,
47934809
return_unweighted_scores: bool = False,
47944810
compute_rasa: bool = False,
47954811
unresolved_cid: List[int] | None = None,
47964812
unresolved_residue_mask: Bool["b n"] | None = None,
47974813
missing_chain_index: int = -1,
4798-
) -> Float[" b"] | Tuple[Float[" b"], SCORED_SAMPLE]:
4814+
) -> Float[" b"] | ScoreDetails:
4815+
47994816
"""Compute the model selection score for an input batch and corresponding (sampled) atom
48004817
positions.
48014818
@@ -4841,7 +4858,7 @@ def compute_model_selection_score(
48414858

48424859
# score samples
48434860

4844-
scored_samples: List[SCORED_SAMPLE] = []
4861+
scored_samples: List[ScoredSample] = []
48454862

48464863
for sample_idx, sample in enumerate(samples):
48474864
atom_pos_pred, pde_logits, dist_logits = sample
@@ -4871,19 +4888,73 @@ def compute_model_selection_score(
48714888

48724889
scored_samples.append((sample_idx, atom_pos_pred, weighted_lddt, gpde))
48734890

4874-
top_ranked_sample = max(
4875-
scored_samples, key=lambda x: x[-1].mean()
4876-
) # rank by batch-averaged gPDE
4877-
best_of_5_sample = max(
4878-
scored_samples, key=lambda x: x[-2].mean()
4879-
) # rank by batch-averaged lDDT
4891+
# quick collate
4892+
4893+
*_, all_weighted_lddt, all_gpde = zip(*scored_samples)
4894+
4895+
# rank by batch-averaged gPDE
48804896

4881-
model_selection_score = (top_ranked_sample[-2] + best_of_5_sample[-2]) / 2
4897+
best_gpde_index = torch.stack(all_gpde).mean(dim = -1).topk(1).indices.item()
48824898

4883-
if return_top_model:
4884-
return model_selection_score, top_ranked_sample
4899+
# rank by batch-averaged lDDT
48854900

4886-
return model_selection_score
4901+
best_lddt_index = torch.stack(all_weighted_lddt).mean(dim = -1).topk(1).indices.item()
4902+
4903+
# some weighted score
4904+
4905+
model_selection_score = (
4906+
scored_samples[best_gpde_index][-2] +
4907+
scored_samples[best_lddt_index][-2]
4908+
) / 2
4909+
4910+
if not return_details:
4911+
return model_selection_score
4912+
4913+
score_details = ScoreDetails(
4914+
best_gpde_index = best_gpde_index,
4915+
best_lddt_index = best_lddt_index,
4916+
score = model_selection_score,
4917+
scored_samples = scored_samples
4918+
)
4919+
4920+
return score_details
4921+
4922+
@typecheck
4923+
def forward(
4924+
self,
4925+
alphafolds: Tuple[Alphafold3],
4926+
batched_atom_inputs: BatchedAtomInput,
4927+
**kwargs
4928+
) -> Float[" b"] | ScoreDetails:
4929+
4930+
"""
4931+
give this a tuple of all the Alphafolds and a batch of atomic inputs
4932+
it will select the best one by the model selection score by returning the index of the Tuple
4933+
"""
4934+
4935+
samples = []
4936+
4937+
with torch.no_grad():
4938+
for alphafold in alphafolds:
4939+
alphafold.eval()
4940+
4941+
pred_atom_pos, logits = alphafold(
4942+
**batched_atom_inputs.model_forward_dict(),
4943+
return_loss = False,
4944+
return_confidence_head_logits = True,
4945+
return_distogram_head_logits = True
4946+
)
4947+
4948+
samples.append((pred_atom_pos, logits.pde, logits.distance))
4949+
4950+
4951+
scores = self.compute_model_selection_score(
4952+
batched_atom_inputs,
4953+
samples = samples,
4954+
**kwargs
4955+
)
4956+
4957+
return scores
48874958

48884959
# main class
48894960

@@ -5383,8 +5454,7 @@ def forward(
53835454
hard_validate: bool = False
53845455
) -> (
53855456
Float['b m 3'] |
5386-
Tuple[Float['b m 3'], ConfidenceHeadLogits] |
5387-
Tuple[Float['b m 3'], ConfidenceHeadLogits, Float['b l n n'] | Float['b l m m']] |
5457+
Tuple[Float['b m 3'], ConfidenceHeadLogits | Alphafold3Logits] |
53885458
Float[''] |
53895459
Tuple[Float[''], LossBreakdown]
53905460
):
@@ -5673,16 +5743,17 @@ def forward(
56735743
return_pae_logits = True
56745744
)
56755745

5676-
if not return_distogram_head_logits:
5677-
return sampled_atom_pos, confidence_head_logits
5746+
returned_logits = confidence_head_logits
56785747

5679-
distogram_head_logits = self.distogram_head(pairwise.clone().detach())
5748+
if return_distogram_head_logits:
5749+
distogram_head_logits = self.distogram_head(pairwise.clone().detach())
56805750

5681-
return (
5682-
sampled_atom_pos,
5683-
confidence_head_logits,
5684-
distogram_head_logits,
5685-
)
5751+
returned_logits = Alphafold3Logits(
5752+
**confidence_head_logits._asdict(),
5753+
distance = distogram_head_logits
5754+
)
5755+
5756+
return sampled_atom_pos, returned_logits
56865757

56875758
# if being forced to return loss, but do not have sufficient information to return losses, just return 0
56885759

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.18"
3+
version = "0.4.0"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
collate_inputs_to_batched_atom_input
4040
)
4141

42+
from alphafold3_pytorch.mocks import MockAtomDataset
43+
4244
from alphafold3_pytorch.configs import (
4345
Alphafold3Config,
4446
create_alphafold3_from_yaml
@@ -61,7 +63,7 @@
6163
PDBInput,
6264
PDBDataset,
6365
default_extract_atom_feats_fn,
64-
default_extract_atompair_feats_fn
66+
default_extract_atompair_feats_fn,
6567
)
6668

6769
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
@@ -1098,7 +1100,6 @@ def test_model_selection_score():
10981100
for chain_len in chain_length
10991101
]).long()
11001102

1101-
11021103
is_molecule_types = torch.zeros_like(asym_id)
11031104
is_molecule_types = torch.nn.functional.one_hot(is_molecule_types, 5).bool()
11041105

@@ -1122,6 +1123,71 @@ def test_model_selection_score():
11221123
is_fine_tuning=False
11231124
)
11241125

1126+
def test_model_selection_score_end_to_end():
1127+
1128+
# prepare two atom inputs for evaluating model selection
1129+
1130+
mock_atom_dataset = MockAtomDataset(10)
1131+
1132+
atom_inputs = [mock_atom_dataset[0], mock_atom_dataset[1]]
1133+
batched_atom_input = collate_inputs_to_batched_atom_input(atom_inputs, atoms_per_window=27)
1134+
1135+
# two models to be selected
1136+
1137+
alphafold3_kwargs = dict(
1138+
dim_atom_inputs = 77,
1139+
dim_pairwise = 8,
1140+
dim_single = 8,
1141+
dim_token = 8,
1142+
atoms_per_window = 27,
1143+
dim_template_feats = 44,
1144+
num_dist_bins = 38,
1145+
confidence_head_kwargs = dict(
1146+
pairformer_depth = 1
1147+
),
1148+
template_embedder_kwargs = dict(
1149+
pairformer_stack_depth = 1
1150+
),
1151+
msa_module_kwargs = dict(
1152+
depth = 1,
1153+
dim_msa = 8,
1154+
),
1155+
pairformer_stack=dict(
1156+
depth=1,
1157+
pair_bias_attn_dim_head = 4,
1158+
pair_bias_attn_heads = 2,
1159+
),
1160+
diffusion_module_kwargs=dict(
1161+
atom_encoder_depth=1,
1162+
token_transformer_depth=1,
1163+
atom_decoder_depth=1,
1164+
atom_decoder_kwargs = dict(
1165+
attn_pair_bias_kwargs = dict(
1166+
dim_head = 4
1167+
)
1168+
),
1169+
atom_encoder_kwargs = dict(
1170+
attn_pair_bias_kwargs = dict(
1171+
dim_head = 4
1172+
)
1173+
)
1174+
),
1175+
)
1176+
1177+
alphafold3_one = Alphafold3(**alphafold3_kwargs)
1178+
alphafold3_two = Alphafold3(**alphafold3_kwargs)
1179+
1180+
alphafolds = (alphafold3_one, alphafold3_two)
1181+
1182+
# evaluate
1183+
1184+
compute_model_selection_score = ComputeModelSelectionScore()
1185+
1186+
details = compute_model_selection_score(alphafolds, batched_atom_input, return_details = True)
1187+
1188+
best_alphafold_by_lddt = alphafolds[details.best_lddt_index]
1189+
assert isinstance(best_alphafold_by_lddt, Alphafold3)
1190+
11251191
def test_unresolved_protein_rasa():
11261192

11271193
# rest of the test

0 commit comments

Comments
 (0)