
Commit 7e65978

improvise a solution for atom resolution confidence heads, without atom attention

1 parent 0be6e64, commit 7e65978

3 files changed: +90, -9 lines changed

alphafold3_pytorch/alphafold3.py (54 additions, 8 deletions)
@@ -2853,10 +2853,10 @@ def forward(
 # confidence head

 class ConfidenceHeadLogits(NamedTuple):
-    pae: Float['b pae n n'] | None
-    pde: Float['b pde n n']
-    plddt: Float['b plddt n']
-    resolved: Float['b 2 n']
+    pae: Float['b pae n n'] | Float['b pae m m'] | None
+    pde: Float['b pde n n'] | Float['b pde m m']
+    plddt: Float['b plddt n'] | Float['b plddt m']
+    resolved: Float['b 2 n'] | Float['b 2 m']

 class ConfidenceHead(Module):
     """ Algorithm 31 """
@@ -2866,6 +2866,8 @@ def __init__(
         self,
         *,
         dim_single_inputs,
+        atom_resolution = False, # @amorehead discovers that the public api has per-atom resolution confidences. improvise a solution
+        dim_atom = 128,
         atompair_dist_bins: List[float],
         dim_single = 384,
         dim_pairwise = 128,
@@ -2918,6 +2920,19 @@ def __init__(
             Rearrange('b ... l -> b l ...')
         )

+        # atom resolution
+        # for now, just embed per atom distances, sum to atom features, project to pairwise dimension
+
+        self.atom_resolution = atom_resolution
+
+        if atom_resolution:
+            self.atom_feats_to_single = LinearNoBias(dim_atom, dim_single)
+            self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
+
+        # tensor typing
+
+        self.da = dim_atom
+
     @typecheck
     def forward(
         self,
@@ -2927,7 +2942,9 @@ def forward(
         pairwise_repr: Float['b n n dp'],
         pred_atom_pos: Float['b n 3'] | Float['b m 3'],
         molecule_atom_indices: Int['b n'] | None = None,
+        molecule_atom_lens: Int['b n'] | None = None,
         mask: Bool['b n'] | None = None,
+        atom_feats: Float['b m {self.da}'] | None = None,
         return_pae_logits = True

     ) -> ConfidenceHeadLogits:
@@ -2938,16 +2955,24 @@ def forward(

         is_atom_seq = pred_atom_pos.shape[-2] > single_inputs_repr.shape[-2]

-        assert not is_atom_seq or exists(molecule_atom_indices)
+        # handle atom resolution vs not
+
+        if self.atom_resolution:
+            assert exists(atom_feats), 'atom_feats must be passed in if atom_resolution is turned on for ConfidenceHead'
+            assert is_atom_seq, '`pred_atom_pos` must be passed in with atomic length'
+            assert exists(molecule_atom_lens)

         if is_atom_seq:
-            pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', pred_atom_pos, molecule_atom_indices)
+            assert exists(molecule_atom_indices), 'molecule_atom_indices must be passed into ConfidenceHead if pred_atom_pos is atomic length'
+            pred_molecule_pos = einx.get_at('b [m] c, b n -> b n c', pred_atom_pos, molecule_atom_indices)
+        else:
+            pred_molecule_pos = pred_atom_pos

         # interatomic distances - embed and add to pairwise

-        interatom_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
+        intermolecule_dist = torch.cdist(pred_molecule_pos, pred_molecule_pos, p = 2)

-        dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', interatom_dist, self.atompair_dist_bins).abs()
+        dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', intermolecule_dist, self.atompair_dist_bins).abs()
         dist_bin_indices = dist_from_dist_bins.argmin(dim = -1)
         pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices)

@@ -2959,6 +2984,27 @@ def forward(
             mask = mask
         )

+        # handle maybe atom level resolution
+
+        if self.atom_resolution:
+            single_repr = repeat_consecutive_with_lens(single_repr, molecule_atom_lens)
+
+            pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
+
+            molecule_atom_lens = repeat(molecule_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr.shape[1])
+            pairwise_repr, ps = pack_one(pairwise_repr, '* n d')
+            pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
+            pairwise_repr = unpack_one(pairwise_repr, ps, '* n d')
+
+            interatomic_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
+
+            dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', interatomic_dist, self.atompair_dist_bins).abs()
+            dist_bin_indices = dist_from_dist_bins.argmin(dim = -1)
+            pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices)
+
+            single_repr = single_repr + self.atom_feats_to_single(atom_feats)
+            pairwise_repr = pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
+
         # to logits

         symmetric_pairwise_repr = pairwise_repr + rearrange(pairwise_repr, 'b i j d -> b j i d')
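The core of the change is a naive broadcast from molecule (token) resolution to atom resolution: the single and pairwise representations are repeated per token according to molecule_atom_lens (via the repo's repeat_consecutive_with_lens helper, with a pack_one/unpack_one round trip to expand the pairwise tensor along both axes), after which true inter-atomic distances are re-embedded and linear projections of atom_feats are summed in. The sketch below illustrates only that broadcast for a single unbatched example; it substitutes torch.repeat_interleave for repeat_consecutive_with_lens, the helper name naive_repeat_with_lens and the dimensions are illustrative, and the batching and padding that the real helper handles are omitted.

# Rough sketch of the molecule -> atom broadcast in the new atom_resolution branch.
# Assumption: repeat_consecutive_with_lens acts like a repeat_interleave over the
# sequence dimension; padding and batching are ignored in this simplified version.

import torch

def naive_repeat_with_lens(feats, lens):
    # feats: (n, ...) molecule-level features for a single example
    # lens:  (n,)     number of atoms per molecule/token
    return torch.repeat_interleave(feats, lens, dim = 0)   # -> (sum(lens), ...)

n, d, dp = 4, 8, 6
lens = torch.tensor([2, 1, 3, 2])          # atoms per token, so m = 8
single_repr = torch.randn(n, d)
pairwise_repr = torch.randn(n, n, dp)

# single representation: one copy per atom
atom_single = naive_repeat_with_lens(single_repr, lens)               # (m, d)

# pairwise representation: expand rows, then columns
atom_pairwise = naive_repeat_with_lens(pairwise_repr, lens)           # (m, n, dp)
atom_pairwise = naive_repeat_with_lens(
    atom_pairwise.transpose(0, 1), lens).transpose(0, 1)              # (m, m, dp)

assert atom_single.shape[0] == int(lens.sum())
assert atom_pairwise.shape[:2] == (int(lens.sum()), int(lens.sum()))

As the commit title notes, this sidesteps running atom attention inside the confidence head, at the cost of the pairwise representation growing quadratically with atom count once broadcast.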
pyproject.toml (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.10"
+version = "0.2.11"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py (35 additions, 0 deletions)
@@ -388,6 +388,41 @@ def test_confidence_head():
         mask = mask
     )

+def test_atom_resolution_confidence_head():
+    single_inputs_repr = torch.randn(2, 16, 77)
+    single_repr = torch.randn(2, 16, 384)
+    pairwise_repr = torch.randn(2, 16, 16, 128)
+    mask = torch.ones((2, 16)).bool()
+
+    atom_seq_len = 32
+    atom_feats = torch.randn(2, atom_seq_len, 64)
+    pred_atom_pos = torch.randn(2, atom_seq_len, 3)
+
+    molecule_atom_indices = torch.randint(0, atom_seq_len, (2, 16)).long()
+    molecule_atom_lens = torch.full((2, 16), 2).long()
+
+    confidence_head = ConfidenceHead(
+        dim_single_inputs = 77,
+        atom_resolution = True,
+        dim_atom = 64,
+        atompair_dist_bins = torch.linspace(3, 20, 37).tolist(),
+        dim_single = 384,
+        dim_pairwise = 128,
+    )
+
+    logits = confidence_head(
+        single_inputs_repr = single_inputs_repr,
+        single_repr = single_repr,
+        atom_feats = atom_feats,
+        molecule_atom_indices = molecule_atom_indices,
+        molecule_atom_lens = molecule_atom_lens,
+        pairwise_repr = pairwise_repr,
+        pred_atom_pos = pred_atom_pos,
+        mask = mask
+    )
+
+    assert logits.pde.shape[-1] == atom_seq_len
+
 def test_input_embedder():

     molecule_atom_lens = torch.randint(0, 3, (2, 16))
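As an illustrative aside (not part of the commit), a few further shape checks could be appended inside the new test, assuming return_pae_logits keeps its default of True from the diff above and following the updated ConfidenceHeadLogits annotations ('b pae m m', 'b plddt m', 'b 2 m'):

    # extra assertions one might add at the end of test_atom_resolution_confidence_head
    assert logits.plddt.shape[-1] == atom_seq_len                     # per-atom plddt logits
    assert logits.resolved.shape[-1] == atom_seq_len                  # per-atom resolved logits
    assert logits.pae.shape[-2:] == (atom_seq_len, atom_seq_len)      # pae returned by default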
