Skip to content

Commit 6ca487d

Browse files
authored
Add templates and atom-wise ligand frames for confidence head outputs (#195)
* Update README.md
* Update __init__.py
* Update data_pipeline.py
* Create kalign.py
* Update life.py
* Update mocks.py
* Create template_parsing.py
* Update alphafold3.py
* Update data_utils.py
* Update inputs.py
* Update model_utils.py
* Update test_af3.py
* Update test_dataloading.py
* Update test_input.py
* Create test_template_loading.py
* Update inputs.py
* Update template_parsing.py
* Update test_af3.py
* Update test_input.py
* Update test_trainer.py
* Update alphafold3.yaml
* Update trainer.yaml
* Update trainer_with_atom_dataset.yaml
* Update trainer_with_atom_dataset_created_from_pdb.yaml
* Update trainer_with_pdb_dataset.yaml
* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml
* Update training.yaml
* Update training_with_pdb_dataset.yaml
* Update test_input.py
* Update inputs.py
1 parent df8b26c commit 6ca487d

24 files changed

+1978
-679
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ A visualization of the molecules of life used in the repository can be seen and
4242

4343
- <a href="https://github.com/xluo233">@xluo233</a> for contributing the confidence measures, clash penalty ranking, and sample ranking logic!
4444

45-
- <a href="https://github.com/sj900">sj900</a> for integrating and testing the `WeightedPDBSampler` within the `PDBDataset` and for adding initial support for MSA parsing!
45+
- <a href="https://github.com/sj900">sj900</a> for integrating and testing the `WeightedPDBSampler` within the `PDBDataset` and for adding initial support for MSA and template parsing!
4646

4747
- <a href="https://github.com/xluo233">@xluo233</a> again for contributing the logic for computing the model selection score as well as the unresolved rasa!
4848

@@ -69,7 +69,7 @@ from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
6969

7070
alphafold3 = Alphafold3(
7171
dim_atom_inputs = 77,
72-
dim_template_feats = 44
72+
dim_template_feats = 108
7373
)
7474

7575
# mock inputs
@@ -91,7 +91,7 @@ is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
9191
is_molecule_mod = torch.randint(0, 2, (2, seq_len, 4)).bool()
9292
molecule_ids = torch.randint(0, 32, (2, seq_len))
9393

94-
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
94+
template_feats = torch.randn(2, 2, seq_len, seq_len, 108)
9595
template_mask = torch.ones((2, 2)).bool()
9696

9797
msa = torch.randn(2, 7, seq_len, 32)
@@ -197,7 +197,7 @@ alphafold3 = Alphafold3(
197197
dim_atom_inputs = 3,
198198
dim_atompair_inputs = 5,
199199
atoms_per_window = 27,
200-
dim_template_feats = 44,
200+
dim_template_feats = 108,
201201
num_dist_bins = 38,
202202
num_molecule_mods = 0,
203203
confidence_head_kwargs = dict(

alphafold3_pytorch/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
SmoothLDDTLoss,
1010
WeightedRigidAlign,
1111
MultiChainPermutationAlignment,
12-
ExpressCoordinatesInFrame,
13-
RigidFrom3Points,
1412
ComputeAlignmentError,
1513
CentreRandomAugmentation,
1614
TemplateEmbedder,
@@ -70,6 +68,10 @@
7068
create_trainer_from_yaml,
7169
create_trainer_from_conductor_yaml
7270
)
71+
from alphafold3_pytorch.utils.model_utils import (
72+
ExpressCoordinatesInFrame,
73+
RigidFrom3Points,
74+
)
7375

7476
__all__ = [
7577
Attention,

alphafold3_pytorch/alphafold3.py

Lines changed: 15 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
)
7171

7272
from alphafold3_pytorch.utils.model_utils import (
73+
ExpressCoordinatesInFrame,
74+
RigidFrom3Points,
7375
calculate_weighted_rigid_align_weights,
7476
)
7577

@@ -3899,93 +3901,6 @@ def forward(
38993901
permuted_true_coords = labels["true_coords"].detach()
39003902
return permuted_true_coords
39013903

3902-
class ExpressCoordinatesInFrame(Module):
    """ Algorithm 29 """

    def __init__(
        self,
        eps = 1e-8
    ):
        super().__init__()
        # forwarded to every `l2norm` call in `forward`
        self.eps = eps

    @typecheck
    def forward(
        self,
        coords: Float['b m 3'],
        frame: Float['b m 3 3'] | Float['b 3 3'] | Float['3 3']
    ) -> Float['b m 3']:
        """
        Re-express `coords` relative to a local frame.

        coords: coordinates to be expressed in the given frame
        frame: frame defined by three points, stored along the last
               dimension; the middle point is used as the origin

        Returns the components of `coords - origin` along the three
        axes built from the frame points.
        """

        # a frame without batch and/or molecule dimensions gets singleton
        # dimensions so that it broadcasts against `coords`

        expand_pattern = {
            2: 'fr fc -> 1 1 fr fc',
            3: 'b fr fc -> b 1 fr fc',
        }.get(frame.ndim)

        if expand_pattern is not None:
            frame = rearrange(frame, expand_pattern)

        # the three frame points live on the last dimension

        first, origin, third = frame.unbind(dim = -1)

        eps = self.eps

        # directions from the origin point to the two outer points
        # (presumably unit length after `l2norm` - helper defined elsewhere)

        to_first = l2norm(first - origin, eps = eps)
        to_third = l2norm(third - origin, eps = eps)

        # basis: normalized sum, normalized difference, and their cross product

        axis_x = l2norm(to_first + to_third, eps = eps)
        axis_y = l2norm(to_third - to_first, eps = eps)
        axis_z = torch.cross(axis_x, axis_y, dim = -1)

        # project the origin-relative coordinates onto each axis

        offset = coords - origin

        return torch.stack(tuple(
            einsum(offset, axis, '... i, ... i -> ...')
            for axis in (axis_x, axis_y, axis_z)
        ), dim = -1)
3948-
3949-
class RigidFrom3Points(Module):
    """
    Algorithm 21 in Section 1.8.1 in Alphafold2 paper
    https://www.nature.com/articles/s41586-021-03819-2

    Builds a rigid transform (R, t) from three points per frame using
    Gram-Schmidt: the translation is the second point, and the columns of
    R are the axes e1, e2, e3 derived from the other two points.
    """

    @typecheck
    def forward(
        self,
        three_points: Tuple[Float['... 3'], Float['... 3'], Float['... 3']] | Float['3 ... 3']
    ) -> Tuple[Float['... 3 3'], Float['... 3']]:
        """
        three_points: either a tuple of three point tensors, or a single
                      tensor with the three points stacked on the first dimension

        Returns `(R, t)`, where `R` stacks the axes (e1, e2, e3) along its
        last dimension and `t` is the second of the three points.
        """

        if isinstance(three_points, tuple):
            three_points = torch.stack(three_points)

        # allow for any number of leading dimensions
        # (`pack_one` is a project helper - presumably flattens the middle
        # dimensions so that each of x1, x2, x3 is (n, 3))

        (x1, x2, x3), unpack_one = pack_one(three_points, 'three * d')

        # main algorithm

        v1 = x3 - x2
        v2 = x1 - x2

        e1 = l2norm(v1)

        # Gram-Schmidt: remove from v2 its component along e1, independently
        # for every point. The points are row vectors of shape (n, 3), so the
        # projection has to be a row-wise dot product - a matrix product
        # `e1 @ (e1.t() @ v2)` would sum outer products across all n points
        # (and collapses u2 to zero when n == 1).

        u2 = v2 - e1 * (e1 * v2).sum(dim = -1, keepdim = True)
        e2 = l2norm(u2)

        e3 = torch.cross(e1, e2, dim = -1)

        R = torch.stack((e1, e2, e3), dim = -1)
        t = x2

        # unpack

        R = unpack_one(R, '* r1 r2')
        t = unpack_one(t, '* c')

        return R, t
3988-
39893904
class ComputeAlignmentError(Module):
39903905
""" Algorithm 30 """
39913906

@@ -6364,7 +6279,16 @@ def forward(
63646279
if hard_debug:
63656280
maybe(hard_validate_atom_indices_ascending)(distogram_atom_indices, 'distogram_atom_indices')
63666281
maybe(hard_validate_atom_indices_ascending)(molecule_atom_indices, 'molecule_atom_indices')
6367-
maybe(hard_validate_atom_indices_ascending)(atom_indices_for_frame, 'atom_indices_for_frame')
6282+
6283+
is_biomolecule = ~(
6284+
(~is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim=-1))
6285+
| (exists(is_molecule_mod) and is_molecule_mod.any(dim=-1))
6286+
)
6287+
maybe(hard_validate_atom_indices_ascending)(
6288+
atom_indices_for_frame,
6289+
'atom_indices_for_frame',
6290+
mask=is_biomolecule,
6291+
)
63686292

63696293
# soft validate
63706294

@@ -6505,11 +6429,6 @@ def forward(
65056429
mask = molecule_atom_lens > 0
65066430
pairwise_mask = to_pairwise_mask(mask)
65076431

6508-
# prepare mask for msa module and template embedder
6509-
# which is equivalent to the `is_protein` of the `is_molecular_types` input
6510-
6511-
is_protein_mask = is_molecule_types[..., IS_PROTEIN_INDEX]
6512-
65136432
# init recycled single and pairwise
65146433

65156434
detach_when_recycling = default(detach_when_recycling, self.detach_when_recycling)
@@ -6546,7 +6465,6 @@ def forward(
65466465
templates = templates,
65476466
template_mask = template_mask,
65486467
pairwise_repr = pairwise,
6549-
mask = is_protein_mask
65506468
)
65516469

65526470
pairwise = embedded_template + pairwise
@@ -6558,7 +6476,6 @@ def forward(
65586476
msa = msa,
65596477
single_repr = single,
65606478
pairwise_repr = pairwise,
6561-
mask = is_protein_mask,
65626479
msa_mask = msa_mask,
65636480
additional_msa_feats = additional_msa_feats
65646481
)
@@ -6961,12 +6878,9 @@ def forward(
69616878
pred_frames, _ = self.rigid_from_three_points(pred_three_atoms)
69626879

69636880
# determine mask
6964-
# must be residue or nucleotide with greater than 0 atoms
6881+
# must be amino acid, nucleotide, or ligand with greater than 0 atoms
69656882

6966-
align_error_mask = (
6967-
is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim=-1)
6968-
& valid_atom_indices_for_frame
6969-
)
6883+
align_error_mask = valid_atom_indices_for_frame
69706884

69716885
# align error
69726886

@@ -6982,7 +6896,7 @@ def forward(
69826896

69836897
pae_labels = distance_to_bins(align_error, self.pae_bins)
69846898

6985-
# set ignore index for invalid molecules or frames (TODO: figure out what is meant by invalid frame)
6899+
# set ignore index for invalid molecules or frames
69866900

69876901
pair_align_error_mask = to_pairwise_mask(align_error_mask)
69886902

0 commit comments

Comments
 (0)