Skip to content

Commit 14e3ae7

Browse files
committed
First pass at deriving molecule-level or atom-level PAE labels for training the confidence head
1 parent b57414d commit 14e3ae7

File tree

5 files changed

+74
-31
lines changed

5 files changed

+74
-31
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5142,14 +5142,19 @@ def __init__(
51425142
num_dist_bins = num_dist_bins
51435143
)
51445144

5145-
# pae bins
5145+
# pae related bins and modules
51465146

51475147
pae_bins_tensor = Tensor(pae_bins)
51485148
self.register_buffer('pae_bins', pae_bins_tensor)
51495149
num_pae_bins = len(pae_bins)
51505150

5151+
self.rigid_from_three_points = RigidFrom3Points()
5152+
self.compute_alignment_error = ComputeAlignmentError()
5153+
51515154
# confidence head
51525155

5156+
self.confidence_head_atom_resolution = confidence_head_atom_resolution
5157+
51535158
self.confidence_head = ConfidenceHead(
51545159
dim_single_inputs = dim_single_inputs,
51555160
atompair_dist_bins = distance_bins,
@@ -5286,7 +5291,6 @@ def forward(
52865291
num_sample_steps: int | None = None,
52875292
atom_pos: Float['b m 3'] | None = None,
52885293
distance_labels: Int['b n n'] | Int['b m m'] | None = None,
5289-
pae_labels: Int['b n n'] | Int['b m m'] | None = None,
52905294
pde_labels: Int['b n n'] | Int['b m m'] | None = None,
52915295
plddt_labels: Int['b n'] | Int['b m'] | None = None,
52925296
resolved_labels: Int['b n'] | Int['b m'] | None = None,
@@ -5544,7 +5548,7 @@ def forward(
55445548

55455549
atom_pos_given = exists(atom_pos)
55465550

5547-
confidence_head_labels = (pae_labels, pde_labels, plddt_labels, resolved_labels)
5551+
confidence_head_labels = (atom_indices_for_frame, pde_labels, plddt_labels, resolved_labels)
55485552
all_labels = (distance_labels, *confidence_head_labels)
55495553

55505554
has_labels = any([*map(exists, all_labels)])
@@ -5592,14 +5596,17 @@ def forward(
55925596
mask = mask,
55935597
return_pae_logits = True
55945598
)
5595-
if return_distogram_head_logits:
5596-
distogram_head_logits = self.distogram_head(pairwise.clone().detach())
5597-
return (
5598-
sampled_atom_pos,
5599-
confidence_head_logits,
5600-
distogram_head_logits,
5601-
)
5602-
return sampled_atom_pos, confidence_head_logits
5599+
5600+
if not return_distogram_head_logits:
5601+
return sampled_atom_pos, confidence_head_logits
5602+
5603+
distogram_head_logits = self.distogram_head(pairwise.clone().detach())
5604+
5605+
return (
5606+
sampled_atom_pos,
5607+
confidence_head_logits,
5608+
distogram_head_logits,
5609+
)
56035610

56045611
# if being forced to return loss, but do not have sufficient information to return losses, just return 0
56055612

@@ -5621,6 +5628,8 @@ def forward(
56215628

56225629
# distogram head
56235630

5631+
molecule_pos = None
5632+
56245633
if not exists(distance_labels) and atom_pos_given and exists(distogram_atom_indices):
56255634
# molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
56265635

@@ -5668,8 +5677,10 @@ def forward(
56685677
additional_molecule_feats,
56695678
is_molecule_types,
56705679
molecule_atom_indices,
5680+
molecule_pos,
5681+
distogram_atom_indices,
5682+
atom_indices_for_frame,
56715683
molecule_atom_lens,
5672-
pae_labels,
56735684
pde_labels,
56745685
plddt_labels,
56755686
resolved_labels,
@@ -5692,8 +5703,10 @@ def forward(
56925703
additional_molecule_feats,
56935704
is_molecule_types,
56945705
molecule_atom_indices,
5706+
molecule_pos,
5707+
distogram_atom_indices,
5708+
atom_indices_for_frame,
56955709
molecule_atom_lens,
5696-
pae_labels,
56975710
pde_labels,
56985711
plddt_labels,
56995712
resolved_labels
@@ -5742,6 +5755,41 @@ def forward(
57425755
return_denoised_pos = True,
57435756
)
57445757

5758+
# determine pae labels if possible
5759+
5760+
pae_labels = None
5761+
ch_atom_res = self.confidence_head_atom_resolution
5762+
5763+
if atom_pos_given and exists(atom_indices_for_frame):
5764+
5765+
denoised_molecule_pos = None
5766+
5767+
if not ch_atom_res:
5768+
assert exists(molecule_pos), '`distogram_atom_indices` must be passed in for calculating non-atomic PAE labels'
5769+
denoised_molecule_pos = denoised_atom_pos.gather(1, distogram_atom_indices)
5770+
5771+
three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)
5772+
pred_three_atoms = einx.get_at('b [m] c, b n three -> three b n c', denoised_atom_pos, atom_indices_for_frame)
5773+
5774+
# compute frames
5775+
5776+
frames, _ = self.rigid_from_three_points(three_atoms)
5777+
pred_frames, _ = self.rigid_from_three_points(pred_three_atoms)
5778+
5779+
# align error
5780+
5781+
align_error = self.compute_alignment_error(
5782+
denoised_atom_pos if ch_atom_res else denoised_molecule_pos,
5783+
atom_pos if ch_atom_res else molecule_pos,
5784+
pred_frames,
5785+
frames,
5786+
molecule_atom_lens = molecule_atom_lens
5787+
)
5788+
5789+
# calculate pae labels as the alignment error discretized into 64 bins (0 - 32 Å)
5790+
5791+
pae_labels = distance_to_bins(align_error, self.pae_bins)
5792+
57455793
# confidence head
57465794

57475795
should_call_confidence_head = any([*map(exists, confidence_head_labels)])

alphafold3_pytorch/inputs.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ class AtomInput:
193193
missing_atom_mask: Bool[' m'] | None = None
194194
molecule_atom_indices: Int[' n'] | None = None
195195
distogram_atom_indices: Int[' n'] | None = None
196+
atom_indices_for_frame: Int['n 3'] | None = None
196197
distance_labels: Int['n n'] | None = None
197-
pae_labels: Int['n n'] | None = None
198198
pde_labels: Int['n n'] | None = None
199199
plddt_labels: Int[' n'] | None = None
200200
resolved_labels: Int[' n'] | None = None
@@ -227,8 +227,8 @@ class BatchedAtomInput:
227227
missing_atom_mask: Bool['b m'] | None = None
228228
molecule_atom_indices: Int['b n'] | None = None
229229
distogram_atom_indices: Int['b n'] | None = None
230+
atom_indices_for_frame: Int['b n 3'] | None = None
230231
distance_labels: Int['b n n'] | None = None
231-
pae_labels: Int['b n n'] | None = None
232232
pde_labels: Int['b n n'] | None = None
233233
plddt_labels: Int['b n'] | None = None
234234
resolved_labels: Int['b n'] | None = None
@@ -444,6 +444,7 @@ class MoleculeInput:
444444
is_molecule_mod: Bool['n num_mods'] | Bool[' n'] | None = None
445445
molecule_atom_indices: List[int | None] | None = None
446446
distogram_atom_indices: List[int | None] | None = None
447+
atom_indices_for_frame: List[Tuple[int, int, int] | None] | None = None
447448
missing_atom_indices: List[Int[' _'] | None] | None = None
448449
missing_token_indices: List[Int[' _'] | None] | None = None
449450
atom_parent_ids: Int[' m'] | None = None
@@ -454,7 +455,6 @@ class MoleculeInput:
454455
template_mask: Bool[' t'] | None = None
455456
msa_mask: Bool[' s'] | None = None
456457
distance_labels: Int['n n'] | None = None
457-
pae_labels: Int['n n'] | None = None
458458
pde_labels: Int[' n'] | None = None
459459
resolved_labels: Int[' n'] | None = None
460460
chains: Tuple[int | None, int | None] | None = (None, None)
@@ -715,6 +715,14 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
715715
if is_molecule_mod.ndim == 1:
716716
is_molecule_mod = rearrange(is_molecule_mod, 'n -> n 1')
717717

718+
# handle `atom_indices_for_frame` for the PAE
719+
720+
atom_indices_for_frame = i.atom_indices_for_frame
721+
722+
if exists(atom_indices_for_frame):
723+
atom_indices_for_frame = [default(indices, (-1, -1, -1)) for indices in i.atom_indices_for_frame]
724+
atom_indices_for_frame = tensor(atom_indices_for_frame)
725+
718726
# atom input
719727

720728
atom_input = AtomInput(
@@ -724,6 +732,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
724732
molecule_ids=i.molecule_ids,
725733
molecule_atom_indices=i.molecule_atom_indices,
726734
distogram_atom_indices=i.distogram_atom_indices,
735+
atom_indices_for_frame=atom_indices_for_frame,
727736
is_molecule_mod=is_molecule_mod,
728737
msa=i.msa,
729738
templates=i.templates,
@@ -773,7 +782,6 @@ class MoleculeLengthMoleculeInput:
773782
template_mask: Bool[' t'] | None = None
774783
msa_mask: Bool[' s'] | None = None
775784
distance_labels: Int['n n'] | None = None
776-
pae_labels: Int['n n'] | None = None
777785
pde_labels: Int[' n'] | None = None
778786
resolved_labels: Int[' n'] | None = None
779787
chains: Tuple[int | None, int | None] | None = (None, None)
@@ -1194,7 +1202,6 @@ class Alphafold3Input:
11941202
template_mask: Bool[' t'] | None = None
11951203
msa_mask: Bool[' s'] | None = None
11961204
distance_labels: Int['n n'] | None = None
1197-
pae_labels: Int['n n'] | None = None
11981205
pde_labels: Int[' n'] | None = None
11991206
resolved_labels: Int[' n'] | None = None
12001207
chains: Tuple[int | None, int | None] | None = (None, None)

alphafold3_pytorch/mocks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def __getitem__(self, idx):
7878
distogram_atom_indices = molecule_atom_lens - 1
7979

8080
distance_labels = torch.randint(0, 37, (seq_len, seq_len))
81-
pae_labels = torch.randint(0, 64, (seq_len, seq_len))
8281
pde_labels = torch.randint(0, 64, (seq_len, seq_len))
8382
plddt_labels = torch.randint(0, 50, (seq_len,))
8483
resolved_labels = torch.randint(0, 2, (seq_len,))
@@ -104,7 +103,6 @@ def __getitem__(self, idx):
104103
molecule_atom_indices = molecule_atom_indices,
105104
distogram_atom_indices = distogram_atom_indices,
106105
distance_labels = distance_labels,
107-
pae_labels = pae_labels,
108106
pde_labels = pde_labels,
109107
plddt_labels = plddt_labels,
110108
resolved_labels = resolved_labels,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.128"
3+
version = "0.3.0"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,6 @@ def test_alphafold3(
611611
molecule_atom_indices = molecule_atom_lens - 1
612612

613613
label_len = atom_seq_len if confidence_head_atom_resolution else seq_len
614-
pae_labels = torch.randint(0, 64, (2, label_len, label_len))
615614
pde_labels = torch.randint(0, 64, (2, label_len, label_len))
616615
plddt_labels = torch.randint(0, 50, (2, label_len))
617616
resolved_labels = torch.randint(0, 2, (2, label_len))
@@ -680,7 +679,6 @@ def test_alphafold3(
680679
atom_pos = atom_pos,
681680
distogram_atom_indices = distogram_atom_indices,
682681
molecule_atom_indices = molecule_atom_indices,
683-
pae_labels = pae_labels,
684682
pde_labels = pde_labels,
685683
plddt_labels = plddt_labels,
686684
resolved_labels = resolved_labels,
@@ -724,7 +722,6 @@ def test_alphafold3_without_msa_and_templates():
724722
distogram_atom_indices = molecule_atom_lens - 1
725723

726724
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
727-
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
728725
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
729726
plddt_labels = torch.randint(0, 50, (2, seq_len))
730727
resolved_labels = torch.randint(0, 2, (2, seq_len))
@@ -777,7 +774,6 @@ def test_alphafold3_without_msa_and_templates():
777774
atom_pos = atom_pos,
778775
distogram_atom_indices = distogram_atom_indices,
779776
distance_labels = distance_labels,
780-
pae_labels = pae_labels,
781777
pde_labels = pde_labels,
782778
plddt_labels = plddt_labels,
783779
resolved_labels = resolved_labels,
@@ -803,7 +799,6 @@ def test_alphafold3_force_return_loss():
803799
molecule_atom_indices = molecule_atom_lens - 1
804800

805801
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
806-
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
807802
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
808803
plddt_labels = torch.randint(0, 50, (2, seq_len))
809804
resolved_labels = torch.randint(0, 2, (2, seq_len))
@@ -845,7 +840,6 @@ def test_alphafold3_force_return_loss():
845840
distogram_atom_indices = distogram_atom_indices,
846841
molecule_atom_indices = molecule_atom_indices,
847842
distance_labels = distance_labels,
848-
pae_labels = pae_labels,
849843
pde_labels = pde_labels,
850844
plddt_labels = plddt_labels,
851845
resolved_labels = resolved_labels,
@@ -888,7 +882,6 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
888882
molecule_atom_indices = molecule_atom_lens - 1
889883

890884
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
891-
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
892885
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
893886
plddt_labels = torch.randint(0, 50, (2, seq_len))
894887
resolved_labels = torch.randint(0, 2, (2, seq_len))
@@ -930,7 +923,6 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
930923
distogram_atom_indices = distogram_atom_indices,
931924
molecule_atom_indices = molecule_atom_indices,
932925
distance_labels = distance_labels,
933-
pae_labels = pae_labels,
934926
pde_labels = pde_labels,
935927
plddt_labels = plddt_labels,
936928
resolved_labels = resolved_labels,
@@ -996,7 +988,6 @@ def test_alphafold3_with_atom_and_bond_embeddings():
996988
molecule_atom_indices = molecule_atom_lens - 1
997989

998990
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
999-
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
1000991
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
1001992
plddt_labels = torch.randint(0, 50, (2, seq_len))
1002993
resolved_labels = torch.randint(0, 2, (2, seq_len))
@@ -1022,7 +1013,6 @@ def test_alphafold3_with_atom_and_bond_embeddings():
10221013
distogram_atom_indices = distogram_atom_indices,
10231014
molecule_atom_indices = molecule_atom_indices,
10241015
distance_labels = distance_labels,
1025-
pae_labels = pae_labels,
10261016
pde_labels = pde_labels,
10271017
plddt_labels = plddt_labels,
10281018
resolved_labels = resolved_labels

0 commit comments

Comments
 (0)