Skip to content

Commit d18f69c

Browse files
committed
completely move PAE over to atom resolution, PDE yet to go
1 parent 314271d commit d18f69c

File tree

3 files changed

+65
-74
lines changed

3 files changed

+65
-74
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 51 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,31 +3099,22 @@ def __init__(
30993099
@typecheck
31003100
def forward(
31013101
self,
3102-
pred_coords: Float['b m_or_n 3'],
3103-
true_coords: Float['b m_or_n 3'],
3102+
pred_coords: Float['b m 3'],
3103+
true_coords: Float['b m 3'],
31043104
pred_frames: Float['b n 3 3'],
31053105
true_frames: Float['b n 3 3'],
3106-
mask: Bool['b m_or_n'] | None = None,
3107-
molecule_atom_lens: Int['b n'] | None = None
3108-
) -> Float['b m_or_n m_or_n']:
3106+
molecule_atom_lens: Int['b n'],
3107+
mask: Bool['b m'] | None = None,
3108+
) -> Float['b m m']:
31093109
"""
31103110
pred_coords: predicted coordinates
31113111
true_coords: true coordinates
31123112
pred_frames: predicted frames
31133113
true_frames: true frames
31143114
"""
31153115

3116-
# detect whether using atom or residue resolution
3117-
3118-
is_atom_resolution = pred_coords.shape[1] != pred_frames.shape[1]
3119-
assert not is_atom_resolution or exists(molecule_atom_lens), '`molecule_atom_lens` must be passed in for atom resolution alignment error'
3120-
3121-
if is_atom_resolution:
3122-
pred_frames = batch_repeat_interleave(pred_frames, molecule_atom_lens)
3123-
true_frames = batch_repeat_interleave(true_frames, molecule_atom_lens)
3124-
3125-
if not exists(mask) and exists(molecule_atom_lens):
3126-
mask = batch_repeat_interleave(molecule_atom_lens > 0, molecule_atom_lens)
3116+
pred_frames = batch_repeat_interleave(pred_frames, molecule_atom_lens)
3117+
true_frames = batch_repeat_interleave(true_frames, molecule_atom_lens)
31273118

31283119
# to pairs
31293120

@@ -3445,7 +3436,7 @@ def forward(
34453436

34463437
class ConfidenceHeadLogits(NamedTuple):
34473438
pae: Float['b pae m m'] | None
3448-
pde: Float['b pde m m']
3439+
pde: Float['b pde n n']
34493440
plddt: Float['b plddt m']
34503441
resolved: Float['b 2 m']
34513442

@@ -3514,6 +3505,8 @@ def __init__(
35143505

35153506
self.atom_feats_to_single = LinearNoBias(dim_atom, dim_single)
35163507

3508+
self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
3509+
35173510
# tensor typing
35183511

35193512
self.da = dim_atom
@@ -3570,12 +3563,16 @@ def forward(
35703563
single_repr=single_repr, pairwise_repr=pairwise_repr, mask=mask
35713564
)
35723565

3573-
# handle atom level resolution
3566+
# handle atom level resolution for single and pairwise
35743567

35753568
atom_single_repr = batch_repeat_interleave(single_repr, molecule_atom_lens)
35763569

35773570
atom_single_repr = atom_single_repr + self.atom_feats_to_single(atom_feats)
35783571

3572+
atom_pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
3573+
3574+
atom_pairwise_repr = atom_pairwise_repr + self.atom_feats_to_pairwise(atom_feats)
3575+
35793576
# to logits
35803577

35813578
pde_logits = self.to_pde_logits(symmetrize(pairwise_repr))
@@ -3588,7 +3585,7 @@ def forward(
35883585
pae_logits = None
35893586

35903587
if return_pae_logits:
3591-
pae_logits = self.to_pae_logits(pairwise_repr)
3588+
pae_logits = self.to_pae_logits(atom_pairwise_repr)
35923589

35933590
# return all logits
35943591

@@ -3642,11 +3639,12 @@ def _calculate_bin_centers(
36423639
@typecheck
36433640
def forward(
36443641
self,
3645-
confidence_head_logits: ConfidenceHeadLogits,
3642+
pae_logits: Float["b pae m m"],
3643+
plddt_logits: Float["b plddt m"],
36463644
asym_id: Int["b n"],
36473645
has_frame: Bool["b n"],
3646+
molecule_atom_lens: Int["b n"],
36483647
ptm_residue_weight: Float["b n"] | None = None,
3649-
molecule_atom_lens: Int["b n"] | None = None,
36503648
multimer_mode: bool = True,
36513649
) -> ConfidenceScore:
36523650
"""Main function to compute confidence score.
@@ -3658,21 +3656,19 @@ def forward(
36583656
:param multimer_mode: bool
36593657
:return: Confidence score
36603658
"""
3661-
plddt = self.compute_plddt(confidence_head_logits.plddt)
3659+
plddt = self.compute_plddt(plddt_logits)
36623660

36633661
# Section 5.9.1 equation 17
36643662
ptm = self.compute_ptm(
3665-
confidence_head_logits.pae, asym_id, has_frame, ptm_residue_weight, interface=False,
3666-
molecule_atom_lens=molecule_atom_lens,
3663+
pae_logits, asym_id, has_frame, molecule_atom_lens, ptm_residue_weight, interface=False
36673664
)
36683665

36693666
iptm = None
36703667

36713668
if multimer_mode:
36723669
# Section 5.9.2 equation 18
36733670
iptm = self.compute_ptm(
3674-
confidence_head_logits.pae, asym_id, has_frame, ptm_residue_weight, interface=True,
3675-
molecule_atom_lens=molecule_atom_lens,
3671+
pae_logits, asym_id, has_frame, molecule_atom_lens, ptm_residue_weight, interface=True
36763672
)
36773673

36783674
confidence_score = ConfidenceScore(plddt=plddt, ptm=ptm, iptm=iptm)
@@ -3700,11 +3696,11 @@ def compute_plddt(
37003696
@typecheck
37013697
def compute_ptm(
37023698
self,
3703-
logits: Float["b pae m_or_n m_or_n"],
3699+
logits: Float["b pae m m"],
37043700
asym_id: Int["b n"],
37053701
has_frame: Bool["b n"],
3702+
molecule_atom_lens: Int["b n"],
37063703
residue_weights: Float["b n"] | None = None,
3707-
molecule_atom_lens: Int["b n"] | None = None,
37083704
interface: bool = False,
37093705
compute_chain_wise_iptm: bool = False,
37103706
) -> Float[" b"] | Tuple[Float["b chains chains"], Bool["b chains chains"], Int["b chains"]]:
@@ -3720,14 +3716,11 @@ def compute_ptm(
37203716
:return: pTM
37213717
"""
37223718

3723-
is_atom_resolution = logits.shape[-1] != asym_id.shape[-1]
3724-
assert not is_atom_resolution or exists(molecule_atom_lens), '`molecule_atom_lens` must be passed in for atom resolution pTM'
37253719

3726-
if is_atom_resolution:
3727-
asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens)
3728-
has_frame = batch_repeat_interleave(has_frame, molecule_atom_lens)
3729-
if exists(residue_weights):
3730-
residue_weights = batch_repeat_interleave(residue_weights, molecule_atom_lens)
3720+
asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens)
3721+
has_frame = batch_repeat_interleave(has_frame, molecule_atom_lens)
3722+
if exists(residue_weights):
3723+
residue_weights = batch_repeat_interleave(residue_weights, molecule_atom_lens)
37313724

37323725
if not exists(residue_weights):
37333726
residue_weights = torch.ones_like(has_frame)
@@ -3988,10 +3981,10 @@ def compute_disorder(
39883981
disorder = ((atom_rasa > 0.581) * mask).sum(dim=-1) / (self.eps + mask.sum(dim=1))
39893982
return disorder
39903983

3991-
@typecheck
39923984
def compute_full_complex_metric(
39933985
self,
3994-
confidence_head_logits: ConfidenceHeadLogits,
3986+
pae_logits: Float['b pae m m'],
3987+
plddt_logits: Float['b plddt m'],
39953988
asym_id: Int["b n"],
39963989
has_frame: Bool["b n"],
39973990
molecule_atom_lens: Int["b n"],
@@ -4003,7 +3996,8 @@ def compute_full_complex_metric(
40033996

40043997
"""Compute full complex metric.
40053998
4006-
:param confidence_head_logits: ConfidenceHeadLogits
3999+
:param pae_logits: pae logits from confidence head
4000+
:param plddt_logits: plddt logits from confidence head
40074001
:param asym_id: [b n] asym_id of each residue
40084002
:param has_frame: [b n] has_frame of each residue
40094003
:param molecule_atom_lens: [b n] molecule atom lens
@@ -4035,7 +4029,7 @@ def compute_full_complex_metric(
40354029
atom_is_molecule_types = is_molecule_types.gather(1, indices) * valid_indices[..., None]
40364030

40374031
confidence_score = self.compute_confidence_score(
4038-
confidence_head_logits, asym_id, has_frame, multimer_mode=True
4032+
pae_logits, plddt_logits, asym_id, has_frame, molecule_atom_lens, multimer_mode=True
40394033
)
40404034
has_clash = self.compute_clash(
40414035
atom_pos,
@@ -4062,9 +4056,11 @@ def compute_full_complex_metric(
40624056
@typecheck
40634057
def compute_single_chain_metric(
40644058
self,
4065-
confidence_head_logits: ConfidenceHeadLogits,
4059+
pae_logits: Float['b pae m m'],
4060+
plddt_logits: Float['b plddt m'],
40664061
asym_id: Int["b n"],
4067-
has_frame: Bool["b n"],
4062+
has_frame: Bool["b n"],
4063+
molecule_atom_lens: Int["b n"]
40684064
) -> Float[" b"]:
40694065

40704066
"""Compute single chain metric.
@@ -4078,18 +4074,18 @@ def compute_single_chain_metric(
40784074
# Section 5.9.3.2
40794075

40804076
confidence_score = self.compute_confidence_score(
4081-
confidence_head_logits, asym_id, has_frame, multimer_mode=False
4077+
pae_logits, plddt_logits, asym_id, has_frame, molecule_atom_lens, multimer_mode=False
40824078
)
40834079

4084-
score = confidence_score.ptm
4085-
return score
4080+
return confidence_score.ptm
40864081

40874082
@typecheck
40884083
def compute_interface_metric(
40894084
self,
4090-
confidence_head_logits: ConfidenceHeadLogits,
4085+
pae_logits: Float['b pae m m'],
40914086
asym_id: Int["b n"],
4092-
has_frame: Bool["b n"],
4087+
has_frame: Bool["b n"],
4088+
molecule_atom_lens: Int['b n'],
40934089
interface_chains: List,
40944090
) -> Float[" b"]:
40954091
"""Compute interface metric.
@@ -4116,7 +4112,7 @@ def compute_interface_metric(
41164112
chain_wise_iptm_mask,
41174113
unique_chains,
41184114
) = self.compute_confidence_score.compute_ptm(
4119-
confidence_head_logits.pae, asym_id, has_frame, compute_chain_wise_iptm=True
4115+
pae_logits, asym_id, has_frame, molecule_atom_lens, compute_chain_wise_iptm=True
41204116
)
41214117

41224118
# Section 5.9.3 equation 20
@@ -4141,7 +4137,7 @@ def compute_interface_metric(
41414137
@typecheck
41424138
def compute_modified_residue_score(
41434139
self,
4144-
confidence_head_logits: ConfidenceHeadLogits,
4140+
plddt_logits: Float['b plddt m'],
41454141
atom_mask: Bool["b m"],
41464142
atom_is_modified_residue: Int["b m"],
41474143
) -> Float[" b"]:
@@ -4155,9 +4151,7 @@ def compute_modified_residue_score(
41554151

41564152
# Section 5.9.3.4
41574153

4158-
plddt = self.compute_confidence_score.compute_plddt(
4159-
confidence_head_logits.plddt,
4160-
)
4154+
plddt = self.compute_confidence_score.compute_plddt(plddt_logits)
41614155

41624156
mask = atom_is_modified_residue * atom_mask
41634157
plddt_mean = masked_average(plddt, mask, dim=-1, eps=self.eps)
@@ -5909,11 +5903,13 @@ def forward(
59095903
& valid_atom_indices_for_frame
59105904
)
59115905

5906+
align_error_mask = batch_repeat_interleave(align_error_mask, molecule_atom_lens)
5907+
59125908
# align error
59135909

59145910
align_error = self.compute_alignment_error(
5915-
denoised_molecule_pos,
5916-
molecule_pos,
5911+
denoised_atom_pos,
5912+
atom_pos,
59175913
pred_frames,
59185914
frames,
59195915
mask=align_error_mask,
@@ -6043,8 +6039,7 @@ def forward(
60436039
# determine which mask to use for confidence head labels
60446040

60456041
label_mask = atom_mask
6046-
6047-
label_pairwise_mask = to_pairwise_mask(mask)
6042+
label_pairwise_mask = to_pairwise_mask(atom_mask)
60486043

60496044
# cross entropy losses
60506045

@@ -6076,7 +6071,7 @@ def cross_entropy_with_weight(logits, labels, weight, ignore_index: int):
60766071
f"pde_labels shape {pde_labels.shape[-1]} does not match "
60776072
f"ch_logits.pde shape {ch_logits.pde.shape[-1]}"
60786073
)
6079-
pde_labels = torch.where(label_pairwise_mask, pde_labels, ignore)
6074+
pde_labels = torch.where(to_pairwise_mask(mask), pde_labels, ignore)
60806075
pde_loss = cross_entropy_with_weight(ch_logits.pde, pde_labels, confidence_weight, ignore)
60816076

60826077
if exists(plddt_labels):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.11"
3+
version = "0.3.12"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,14 @@ def test_rigid_from_three_points():
186186
assert rotation.shape == (7, 11, 23, 3, 3)
187187

188188
def test_compute_alignment_error():
189+
molecule_atom_lens = torch.ones(2, 100).long()
189190
pred_coords = torch.randn(2, 100, 3)
190191
pred_frames = torch.randn(2, 100, 3, 3)
191192

192193
# `pred_coords` should match itself in frame basis
193194

194195
error_fn = ComputeAlignmentError()
195-
alignment_errors = error_fn(pred_coords, pred_coords, pred_frames, pred_frames)
196+
alignment_errors = error_fn(pred_coords, pred_coords, pred_frames, pred_frames, molecule_atom_lens = molecule_atom_lens)
196197

197198
assert alignment_errors.shape == (2, 100, 100)
198199
assert (alignment_errors.mean(-1) < 1e-3).all()
@@ -479,7 +480,7 @@ def test_confidence_head():
479480
mask=mask,
480481
)
481482

482-
assert logits.pae.shape[-1] == seq_len
483+
assert logits.pae.shape[-1] == atom_seq_len
483484
assert logits.pde.shape[-1] == seq_len
484485

485486
assert logits.plddt.shape[-1] == atom_seq_len
@@ -995,12 +996,12 @@ def test_compute_ranking_score():
995996
has_frame = torch.randint(0, 2, (batch_size, seq_len)).bool()
996997
is_modified_residue = torch.randint(0, 2, (batch_size, atom_seq_len))
997998

998-
pae_logits = torch.randn(batch_size, 64, seq_len, seq_len)
999+
pae_logits = torch.randn(batch_size, 64, atom_seq_len, atom_seq_len)
9991000
pde_logits = torch.randn(batch_size, 64, seq_len, seq_len)
10001001
plddt_logits = torch.randn(batch_size, 50, atom_seq_len)
1001-
resolved_logits = torch.randint(0, 2, (batch_size, 2, seq_len))
1002+
resolved_logits = torch.randint(0, 2, (batch_size, 2, atom_seq_len))
1003+
10021004
confidence_head_logits = ConfidenceHeadLogits(pae_logits, pde_logits, plddt_logits, resolved_logits)
1003-
atom_level_pae_logits = torch.randn(batch_size, 64, atom_seq_len, atom_seq_len)
10041005

10051006
chain_length = [random.randint(seq_len // 4, seq_len //2)
10061007
for _ in range(batch_size)]
@@ -1010,37 +1011,32 @@ def test_compute_ranking_score():
10101011
for chain_len in chain_length
10111012
]).long()
10121013

1013-
10141014
compute_ranking_score = ComputeRankingScore()
10151015

10161016
full_complex_metric = compute_ranking_score.compute_full_complex_metric(
1017-
confidence_head_logits, asym_id, has_frame, molecule_atom_lens,
1018-
atom_pos, atom_mask, is_molecule_types)
1017+
pae_logits, plddt_logits, asym_id, has_frame, molecule_atom_lens,
1018+
atom_pos, atom_mask, is_molecule_types
1019+
)
10191020

10201021
single_chain_metric = compute_ranking_score.compute_single_chain_metric(
1021-
confidence_head_logits, asym_id, has_frame,)
1022+
pae_logits, plddt_logits, asym_id, has_frame, molecule_atom_lens)
10221023

10231024
interface_metric = compute_ranking_score.compute_interface_metric(
1024-
confidence_head_logits, asym_id, has_frame,
1025+
pae_logits, asym_id, has_frame, molecule_atom_lens,
10251026
interface_chains=[(0, 1), (1,)])
10261027

10271028
modified_residue_score = compute_ranking_score.compute_modified_residue_score(
1028-
confidence_head_logits, atom_mask, is_modified_residue)
1029-
1030-
residue_level_ptm_score = compute_ranking_score.compute_confidence_score.compute_ptm(
1031-
pae_logits, asym_id, has_frame
1032-
)
1029+
plddt_logits, atom_mask, is_modified_residue)
10331030

10341031
atom_level_ptm_score = compute_ranking_score.compute_confidence_score.compute_ptm(
1035-
atom_level_pae_logits, asym_id, has_frame,
1032+
pae_logits, asym_id, has_frame,
10361033
molecule_atom_lens=molecule_atom_lens
10371034
)
10381035

10391036
assert full_complex_metric.numel() == batch_size
10401037
assert single_chain_metric.numel() == batch_size
10411038
assert interface_metric.numel() == batch_size
10421039
assert modified_residue_score.numel() == batch_size
1043-
assert residue_level_ptm_score.numel() == batch_size
10441040
assert atom_level_ptm_score.numel() == batch_size
10451041

10461042
def test_model_selection_score():

0 commit comments

Comments
 (0)