Skip to content

Commit 42ad45a

Browse files
committed
make sure one can return sampled atom positions along with confidence head logits
1 parent b264bc9 commit 42ad45a

File tree

3 files changed

+119
-3
lines changed

3 files changed

+119
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3386,9 +3386,16 @@ def forward(
33863386
return_loss_breakdown = False,
33873387
return_loss: bool = None,
33883388
return_present_sampled_atoms: bool = False,
3389+
return_confidence_head_logits: bool = False,
33893390
num_rollout_steps: int = 20,
33903391
rollout_show_tqdm_pbar: bool = False
3391-
) -> Float['b m 3'] | Float['l 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
3392+
) -> (
3393+
Float['b m 3'] |
3394+
Float['l 3'] |
3395+
Tuple[Float['b m 3'] | Float['l 3'], ConfidenceHeadLogits] |
3396+
Float[''] |
3397+
Tuple[Float[''], LossBreakdown]
3398+
):
33923399

33933400
atom_seq_len = atom_inputs.shape[-2]
33943401

@@ -3635,7 +3642,25 @@ def forward(
36353642
if exists(missing_atom_mask) and return_present_sampled_atoms:
36363643
sampled_atom_pos = sampled_atom_pos[~missing_atom_mask]
36373644

3638-
return sampled_atom_pos
3645+
if not return_confidence_head_logits:
3646+
return sampled_atom_pos
3647+
3648+
# todo - handle missing atoms
3649+
3650+
assert exists(molecule_atom_indices)
3651+
3652+
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', sampled_atom_pos, molecule_atom_indices)
3653+
3654+
logits = self.confidence_head(
3655+
single_repr = single.detach(),
3656+
single_inputs_repr = single_inputs.detach(),
3657+
pairwise_repr = pairwise.detach(),
3658+
pred_atom_pos = pred_atom_pos.detach(),
3659+
mask = mask,
3660+
return_pae_logits = True
3661+
)
3662+
3663+
return sampled_atom_pos, logits
36393664

36403665
# if being forced to return loss, but do not have sufficient information to return losses, just return 0
36413666

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.8"
3+
version = "0.2.9"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def test_alphafold3(
472472

473473
atom_pos = torch.randn(2, atom_seq_len, 3)
474474
distogram_atom_indices = molecule_atom_lens - 1
475+
molecule_atom_indices = molecule_atom_lens - 1
475476

476477
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
477478
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
@@ -524,6 +525,7 @@ def test_alphafold3(
524525
template_mask = template_mask,
525526
atom_pos = atom_pos,
526527
distogram_atom_indices = distogram_atom_indices,
528+
molecule_atom_indices = molecule_atom_indices,
527529
pae_labels = pae_labels,
528530
pde_labels = pde_labels,
529531
plddt_labels = plddt_labels,
@@ -630,6 +632,7 @@ def test_alphafold3_force_return_loss():
630632

631633
atom_pos = torch.randn(2, atom_seq_len, 3)
632634
distogram_atom_indices = molecule_atom_lens - 1
635+
molecule_atom_indices = molecule_atom_lens - 1
633636

634637
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
635638
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
@@ -671,6 +674,7 @@ def test_alphafold3_force_return_loss():
671674
additional_token_feats = additional_token_feats,
672675
atom_pos = atom_pos,
673676
distogram_atom_indices = distogram_atom_indices,
677+
molecule_atom_indices = molecule_atom_indices,
674678
distance_labels = distance_labels,
675679
pae_labels = pae_labels,
676680
pde_labels = pde_labels,
@@ -682,6 +686,91 @@ def test_alphafold3_force_return_loss():
682686

683687
assert sampled_atom_pos.ndim == 3
684688

689+
loss, _ = alphafold3(
690+
num_recycling_steps = 2,
691+
atom_inputs = atom_inputs,
692+
molecule_ids = molecule_ids,
693+
molecule_atom_lens = molecule_atom_lens,
694+
atompair_inputs = atompair_inputs,
695+
is_molecule_types = is_molecule_types,
696+
additional_molecule_feats = additional_molecule_feats,
697+
additional_token_feats = additional_token_feats,
698+
molecule_atom_indices = molecule_atom_indices,
699+
return_loss_breakdown = True,
700+
return_loss = True # force returning loss even if no labels given
701+
)
702+
703+
assert loss == 0.
704+
705+
def test_alphafold3_force_return_loss_with_confidence_logits():
706+
seq_len = 16
707+
molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
708+
atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
709+
710+
atom_inputs = torch.randn(2, atom_seq_len, 77)
711+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
712+
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
713+
additional_token_feats = torch.randn(2, seq_len, 2)
714+
is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool()
715+
molecule_ids = torch.randint(0, 32, (2, seq_len))
716+
717+
atom_pos = torch.randn(2, atom_seq_len, 3)
718+
distogram_atom_indices = molecule_atom_lens - 1
719+
molecule_atom_indices = molecule_atom_lens - 1
720+
721+
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
722+
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
723+
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
724+
plddt_labels = torch.randint(0, 50, (2, seq_len))
725+
resolved_labels = torch.randint(0, 2, (2, seq_len))
726+
727+
alphafold3 = Alphafold3(
728+
dim_atom_inputs = 77,
729+
dim_template_feats = 44,
730+
num_dist_bins = 38,
731+
confidence_head_kwargs = dict(
732+
pairformer_depth = 1
733+
),
734+
template_embedder_kwargs = dict(
735+
pairformer_stack_depth = 1
736+
),
737+
msa_module_kwargs = dict(
738+
depth = 1
739+
),
740+
pairformer_stack = dict(
741+
depth = 2
742+
),
743+
diffusion_module_kwargs = dict(
744+
atom_encoder_depth = 1,
745+
token_transformer_depth = 1,
746+
atom_decoder_depth = 1,
747+
),
748+
)
749+
750+
sampled_atom_pos, confidence_head_logits = alphafold3(
751+
num_recycling_steps = 2,
752+
atom_inputs = atom_inputs,
753+
molecule_ids = molecule_ids,
754+
molecule_atom_lens = molecule_atom_lens,
755+
atompair_inputs = atompair_inputs,
756+
is_molecule_types = is_molecule_types,
757+
additional_molecule_feats = additional_molecule_feats,
758+
additional_token_feats = additional_token_feats,
759+
atom_pos = atom_pos,
760+
distogram_atom_indices = distogram_atom_indices,
761+
molecule_atom_indices = molecule_atom_indices,
762+
distance_labels = distance_labels,
763+
pae_labels = pae_labels,
764+
pde_labels = pde_labels,
765+
plddt_labels = plddt_labels,
766+
resolved_labels = resolved_labels,
767+
return_loss_breakdown = True,
768+
return_loss = False, # force sampling even if labels are given
769+
return_confidence_head_logits = True
770+
)
771+
772+
assert sampled_atom_pos.ndim == 3
773+
685774
loss, _ = alphafold3(
686775
num_recycling_steps = 2,
687776
atom_inputs = atom_inputs,
@@ -733,6 +822,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
733822

734823
atom_pos = torch.randn(2, atom_seq_len, 3)
735824
distogram_atom_indices = molecule_atom_lens - 1 # last atom, as an example
825+
molecule_atom_indices = molecule_atom_lens - 1
736826

737827
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
738828
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
@@ -759,6 +849,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
759849
template_mask = template_mask,
760850
atom_pos = atom_pos,
761851
distogram_atom_indices = distogram_atom_indices,
852+
molecule_atom_indices = molecule_atom_indices,
762853
distance_labels = distance_labels,
763854
pae_labels = pae_labels,
764855
pde_labels = pde_labels,

0 commit comments

Comments
 (0)