Skip to content

Commit 9e902f1

Browse files
authored
atom resolution distogram (#164)
1 parent 11f589e commit 9e902f1

File tree

3 files changed

+81
-18
lines changed

3 files changed

+81
-18
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,7 +3376,9 @@ def __init__(
33763376
self,
33773377
*,
33783378
dim_pairwise = 128,
3379-
num_dist_bins = 38, # think it is 38?
3379+
num_dist_bins = 38,
3380+
dim_atom = 128,
3381+
atom_resolution = False
33803382
):
33813383
super().__init__()
33823384

@@ -3385,13 +3387,42 @@ def __init__(
33853387
Rearrange('b ... l -> b l ...')
33863388
)
33873389

3390+
# atom resolution
3391+
# for now, just embed per atom distances, sum to atom features, project to pairwise dimension
3392+
3393+
self.atom_resolution = atom_resolution
3394+
3395+
if atom_resolution:
3396+
self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
3397+
3398+
# tensor typing
3399+
3400+
self.da = dim_atom
3401+
33883402
@typecheck
33893403
def forward(
33903404
self,
3391-
pairwise_repr: Float['b n n d']
3392-
) -> Float['b l n n']:
3405+
pairwise_repr: Float['b n n d'],
3406+
molecule_atom_lens: Int['b n'] | None = None,
3407+
atom_feats: Float['b m {self.da}'] | None = None,
3408+
) -> Float['b l n n'] | Float['b l m m']:
3409+
3410+
if self.atom_resolution:
3411+
assert exists(molecule_atom_lens)
3412+
assert exists(atom_feats)
3413+
3414+
pairwise_repr = batch_repeat_interleave(pairwise_repr, molecule_atom_lens)
3415+
3416+
molecule_atom_lens = repeat(molecule_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr.shape[1])
3417+
pairwise_repr, unpack_one = pack_one(pairwise_repr, '* n d')
3418+
pairwise_repr = batch_repeat_interleave(pairwise_repr, molecule_atom_lens)
3419+
pairwise_repr = unpack_one(pairwise_repr)
3420+
3421+
pairwise_repr = pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
3422+
3423+
symmetrized_pairwise_repr = pairwise_repr + rearrange(pairwise_repr, 'b i j d -> b j i d')
3424+
logits = self.to_distogram_logits(symmetrized_pairwise_repr)
33933425

3394-
logits = self.to_distogram_logits(pairwise_repr)
33953426
return logits
33963427

33973428
# confidence head
@@ -4973,6 +5004,7 @@ def __init__(
49735004
augment_kwargs: dict = dict(),
49745005
stochastic_frame_average = False,
49755006
confidence_head_atom_resolution = False,
5007+
distogram_atom_resolution = False,
49765008
checkpoint_input_embedding = False,
49775009
checkpoint_trunk_pairformer = False,
49785010
checkpoint_diffusion_token_transformer = False,
@@ -5147,9 +5179,13 @@ def __init__(
51475179

51485180
assert len(distance_bins_tensor) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51495181

5182+
self.distogram_atom_resolution = distogram_atom_resolution
5183+
51505184
self.distogram_head = DistogramHead(
51515185
dim_pairwise = dim_pairwise,
5152-
num_dist_bins = num_dist_bins
5186+
dim_atom = dim_atom,
5187+
num_dist_bins = num_dist_bins,
5188+
atom_resolution = distogram_atom_resolution,
51535189
)
51545190

51555191
# pae related bins and modules
@@ -5635,22 +5671,41 @@ def forward(
56355671
molecule_pos = None
56365672

56375673
if not exists(distance_labels) and atom_pos_given and exists(distogram_atom_indices):
5638-
# molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
56395674

5640-
distogram_atom_indices = repeat(distogram_atom_indices, 'b n -> b n c', c = atom_pos.shape[-1])
5641-
molecule_pos = atom_pos.gather(1, distogram_atom_indices)
5675+
distogram_pos = atom_pos
56425676

5643-
molecule_dist = torch.cdist(molecule_pos, molecule_pos, p = 2)
5644-
distance_labels = distance_to_bins(molecule_dist, self.distance_bins)
5677+
if not self.distogram_atom_resolution:
5678+
# molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
5679+
5680+
distogram_atom_indices = repeat(distogram_atom_indices, 'b n -> b n c', c = distogram_pos.shape[-1])
5681+
molecule_pos = distogram_pos = distogram_pos.gather(1, distogram_atom_indices)
5682+
distogram_mask = valid_distogram_mask
5683+
else:
5684+
distogram_mask = atom_mask
5685+
5686+
distogram_dist = torch.cdist(distogram_pos, distogram_pos, p = 2)
5687+
distance_labels = distance_to_bins(distogram_dist, self.distance_bins)
56455688

56465689
# account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
56475690

5648-
valid_distogram_mask = to_pairwise_mask(valid_distogram_mask)
5649-
distance_labels.masked_fill_(~valid_distogram_mask, ignore)
5691+
distogram_mask = to_pairwise_mask(distogram_mask)
5692+
distance_labels.masked_fill_(~distogram_mask, ignore)
56505693

56515694
if exists(distance_labels):
5652-
distance_labels = torch.where(pairwise_mask, distance_labels, ignore)
5653-
distogram_logits = self.distogram_head(pairwise)
5695+
5696+
distogram_mask = pairwise_mask
5697+
5698+
if self.distogram_atom_resolution:
5699+
distogram_mask = to_pairwise_mask(atom_mask)
5700+
5701+
distance_labels = torch.where(distogram_mask, distance_labels, ignore)
5702+
5703+
distogram_logits = self.distogram_head(
5704+
pairwise,
5705+
molecule_atom_lens = molecule_atom_lens,
5706+
atom_feats = atom_feats
5707+
)
5708+
56545709
distogram_loss = F.cross_entropy(distogram_logits, distance_labels, ignore_index = ignore)
56555710

56565711
# otherwise, noise and make it learn to denoise
@@ -5771,7 +5826,12 @@ def forward(
57715826
denoised_molecule_pos = None
57725827

57735828
if not ch_atom_res:
5774-
assert exists(molecule_pos), '`distogram_atom_indices` must be passed in for calculating non-atomic PAE labels'
5829+
if not exists(molecule_pos):
5830+
assert exists(distogram_atom_indices), '`distogram_atom_indices` must be passed in for calculating non-atomic PAE labels'
5831+
5832+
distogram_atom_indices = repeat(distogram_atom_indices, 'b n -> b n c', c = distogram_pos.shape[-1])
5833+
molecule_pos = atom_pos.gather(1, distogram_atom_indices)
5834+
57755835
denoised_molecule_pos = denoised_atom_pos.gather(1, distogram_atom_indices)
57765836

57775837
# three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.5"
3+
version = "0.3.6"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,16 @@ def test_distogram_head():
554554
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
555555
@pytest.mark.parametrize('num_molecule_mods', (0, 4))
556556
@pytest.mark.parametrize('confidence_head_atom_resolution', (True, False))
557+
@pytest.mark.parametrize('distogram_atom_resolution', (True, False))
557558
def test_alphafold3(
558559
window_atompair_inputs: bool,
559560
stochastic_frame_average: bool,
560561
missing_atoms: bool,
561562
calculate_pae: bool,
562563
atom_transformer_intramolecular_attn: bool,
563564
num_molecule_mods: int,
564-
confidence_head_atom_resolution: bool
565+
confidence_head_atom_resolution: bool,
566+
distogram_atom_resolution: bool
565567
):
566568
seq_len = 16
567569
atoms_per_window = 27
@@ -655,7 +657,8 @@ def test_alphafold3(
655657
)
656658
),
657659
stochastic_frame_average = stochastic_frame_average,
658-
confidence_head_atom_resolution = confidence_head_atom_resolution
660+
confidence_head_atom_resolution = confidence_head_atom_resolution,
661+
distogram_atom_resolution = distogram_atom_resolution
659662
)
660663

661664
loss, breakdown = alphafold3(

0 commit comments

Comments
 (0)