Skip to content

Commit 52653ab

Browse files
committed
back to token level pae
1 parent bc9e7e8 commit 52653ab

File tree

2 files changed

+16
-50
lines changed

2 files changed

+16
-50
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3134,32 +3134,18 @@ def __init__(
31343134
@typecheck
31353135
def forward(
31363136
self,
3137-
pred_coords: Float['b m_or_n 3'],
3138-
true_coords: Float['b m_or_n 3'],
3137+
pred_coords: Float['b n 3'],
3138+
true_coords: Float['b n 3'],
31393139
pred_frames: Float['b n 3 3'],
31403140
true_frames: Float['b n 3 3'],
3141-
mask: Bool['b m_or_n'] | None = None,
3142-
molecule_atom_lens: Int['b n'] | None = None
3143-
) -> Float['b m_or_n m_or_n']:
3141+
mask: Bool['b n'] | None = None,
3142+
) -> Float['b n n']:
31443143
"""
31453144
pred_coords: predicted coordinates
31463145
true_coords: true coordinates
31473146
pred_frames: predicted frames
31483147
true_frames: true frames
31493148
"""
3150-
3151-
# detect whether using atom or residue resolution
3152-
3153-
is_atom_resolution = pred_coords.shape[1] != pred_frames.shape[1]
3154-
assert not is_atom_resolution or exists(molecule_atom_lens), '`molecule_atom_lens` must be passed in for atom resolution alignment error'
3155-
3156-
if is_atom_resolution:
3157-
pred_frames = batch_repeat_interleave(pred_frames, molecule_atom_lens)
3158-
true_frames = batch_repeat_interleave(true_frames, molecule_atom_lens)
3159-
3160-
if not exists(mask) and exists(molecule_atom_lens):
3161-
mask = batch_repeat_interleave(molecule_atom_lens > 0, molecule_atom_lens)
3162-
31633149
# to pairs
31643150

31653151
seq = pred_coords.shape[1]
@@ -3681,7 +3667,6 @@ def forward(
36813667
asym_id: Int["b n"],
36823668
has_frame: Bool["b n"],
36833669
ptm_residue_weight: Float["b n"] | None = None,
3684-
molecule_atom_lens: Int["b n"] | None = None,
36853670
multimer_mode: bool = True,
36863671
) -> ConfidenceScore:
36873672
"""Main function to compute confidence score.
@@ -3698,7 +3683,6 @@ def forward(
36983683
# Section 5.9.1 equation 17
36993684
ptm = self.compute_ptm(
37003685
confidence_head_logits.pae, asym_id, has_frame, ptm_residue_weight, interface=False,
3701-
molecule_atom_lens=molecule_atom_lens,
37023686
)
37033687

37043688
iptm = None
@@ -3707,7 +3691,6 @@ def forward(
37073691
# Section 5.9.2 equation 18
37083692
iptm = self.compute_ptm(
37093693
confidence_head_logits.pae, asym_id, has_frame, ptm_residue_weight, interface=True,
3710-
molecule_atom_lens=molecule_atom_lens,
37113694
)
37123695

37133696
confidence_score = ConfidenceScore(plddt=plddt, ptm=ptm, iptm=iptm)
@@ -3735,11 +3718,10 @@ def compute_plddt(
37353718
@typecheck
37363719
def compute_ptm(
37373720
self,
3738-
logits: Float["b pae m_or_n m_or_n"],
3721+
pae_logits: Float["b pae n n"],
37393722
asym_id: Int["b n"],
37403723
has_frame: Bool["b n"],
37413724
residue_weights: Float["b n"] | None = None,
3742-
molecule_atom_lens: Int["b n"] | None = None,
37433725
interface: bool = False,
37443726
compute_chain_wise_iptm: bool = False,
37453727
) -> Float[" b"] | Tuple[Float["b chains chains"], Bool["b chains chains"], Int["b chains"]]:
@@ -3753,25 +3735,15 @@ def compute_ptm(
37533735
:param interface: bool
37543736
:param compute_chain_wise_iptm: bool
37553737
:return: pTM
3756-
"""
3757-
3758-
is_atom_resolution = logits.shape[-1] != asym_id.shape[-1]
3759-
assert not is_atom_resolution or exists(molecule_atom_lens), '`molecule_atom_lens` must be passed in for atom resolution pTM'
3760-
3761-
if is_atom_resolution:
3762-
asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens)
3763-
has_frame = batch_repeat_interleave(has_frame, molecule_atom_lens)
3764-
if exists(residue_weights):
3765-
residue_weights = batch_repeat_interleave(residue_weights, molecule_atom_lens)
3766-
3738+
"""
37673739
if not exists(residue_weights):
37683740
residue_weights = torch.ones_like(has_frame)
37693741

37703742
residue_weights = residue_weights * has_frame
37713743

3772-
num_batch = logits.shape[0]
3773-
num_res = logits.shape[-1]
3774-
logits = rearrange(logits, "b c i j -> b i j c")
3744+
num_batch, *_, num_res, device = *pae_logits.shape, pae_logits.device
3745+
3746+
pae_logits = rearrange(pae_logits, "b c i j -> b i j c")
37753747

37763748
bin_centers = self._calculate_bin_centers(self.pae_breaks)
37773749

@@ -3788,7 +3760,7 @@ def compute_ptm(
37883760
tm_per_bin = 1.0 / (1 + torch.square(bin_centers[None, :]) / torch.square(d0[..., None]))
37893761

37903762
# Convert logits to probs.
3791-
probs = F.softmax(logits, dim=-1)
3763+
probs = F.softmax(pae_logits, dim=-1)
37923764

37933765
# E_distances tm(distance).
37943766
predicted_tm_term = einsum(probs, tm_per_bin, "b i j pae, b pae -> b i j ")
@@ -3801,7 +3773,7 @@ def compute_ptm(
38013773
max_chains = max(len(chains) for chains in unique_chains)
38023774

38033775
chain_wise_iptm = torch.zeros(
3804-
(num_batch, max_chains, max_chains), device=logits.device
3776+
(num_batch, max_chains, max_chains), device=device
38053777
)
38063778
chain_wise_iptm_mask = torch.zeros_like(chain_wise_iptm).bool()
38073779

@@ -3837,7 +3809,7 @@ def compute_ptm(
38373809
return chain_wise_iptm, chain_wise_iptm_mask, torch.tensor(unique_chains)
38383810

38393811
else:
3840-
pair_mask = torch.ones(size=(num_batch, num_res, num_res), device=logits.device).bool()
3812+
pair_mask = torch.ones(size=(num_batch, num_res, num_res), device=device).bool()
38413813
if interface:
38423814
pair_mask *= asym_id[:, :, None] != asym_id[:, None, :]
38433815

@@ -3857,13 +3829,14 @@ def compute_ptm(
38573829
@typecheck
38583830
def compute_pde(
38593831
self,
3860-
logits: Float["b pde n n"],
3832+
pde_logits: Float["b pde n n"],
38613833
tok_repr_atm_mask: Bool["b n"],
38623834
) -> Float["b n n"]:
38633835
"""Compute PDE from logits."""
3864-
logits = rearrange(logits, "b pde i j -> b i j pde")
3836+
3837+
pde_logits = rearrange(pde_logits, "b pde i j -> b i j pde")
38653838
bin_centers = self._calculate_bin_centers(self.pde_breaks)
3866-
probs = F.softmax(logits, dim=-1)
3839+
probs = F.softmax(pde_logits, dim=-1)
38673840

38683841
pde = einsum(probs, bin_centers, "b i j pde, pde -> b i j")
38693842

@@ -5968,7 +5941,6 @@ def forward(
59685941
pred_frames,
59695942
frames,
59705943
mask=align_error_mask,
5971-
molecule_atom_lens=molecule_atom_lens,
59725944
)
59735945

59745946
# calculate pae labels as alignment error binned to 64 (0 - 32A) (TODO: double-check correctness of `distance_to_bins`'s bin assignments)

tests/test_af3.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,17 +1066,11 @@ def test_compute_ranking_score():
10661066
pae_logits, asym_id, has_frame
10671067
)
10681068

1069-
atom_level_ptm_score = compute_ranking_score.compute_confidence_score.compute_ptm(
1070-
atom_level_pae_logits, asym_id, has_frame,
1071-
molecule_atom_lens=molecule_atom_lens
1072-
)
1073-
10741069
assert full_complex_metric.numel() == batch_size
10751070
assert single_chain_metric.numel() == batch_size
10761071
assert interface_metric.numel() == batch_size
10771072
assert modified_residue_score.numel() == batch_size
10781073
assert residue_level_ptm_score.numel() == batch_size
1079-
assert atom_level_ptm_score.numel() == batch_size
10801074

10811075
def test_model_selection_score():
10821076

0 commit comments

Comments
 (0)