Skip to content

Commit 11f589e

Browse files
committed
one more cleanup
1 parent 63890bd commit 11f589e

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3683,7 +3683,8 @@ def compute_ptm(
36833683
molecule_atom_lens: Int["b n"] | None = None,
36843684
interface: bool = False,
36853685
compute_chain_wise_iptm: bool = False,
3686-
):
3686+
) -> Float[" b"] | Tuple[Float["b chains chains"], Bool["b chains chains"], Int["b chains"]]:
3687+
36873688
"""Compute pTM from logits.
36883689
36893690
:param logits: [b c n n] logits
@@ -3759,22 +3760,22 @@ def compute_ptm(
37593760
if pair_residue_weights.sum() == 0:
37603761
# chain i or chain j does not have any valid frame
37613762
continue
3762-
else:
3763-
normed_residue_mask = pair_residue_weights / (
3764-
self.eps
3765-
+ torch.sum(pair_residue_weights, dim=-1, keepdims=True)
3766-
)
37673763

3768-
masked_predicted_tm_term = predicted_tm_term[b] * pair_mask
3764+
normed_residue_mask = pair_residue_weights / (
3765+
self.eps
3766+
+ torch.sum(pair_residue_weights, dim=-1, keepdims=True)
3767+
)
3768+
3769+
masked_predicted_tm_term = predicted_tm_term[b] * pair_mask
37693770

3770-
per_alignment = torch.sum(
3771-
masked_predicted_tm_term * normed_residue_mask, dim=-1
3772-
)
3773-
weighted_argmax = (residue_weights[b] * per_alignment).argmax()
3774-
chain_wise_iptm[b, i, j] = per_alignment[weighted_argmax]
3775-
chain_wise_iptm_mask[b, i, j] = True
3771+
per_alignment = torch.sum(
3772+
masked_predicted_tm_term * normed_residue_mask, dim=-1
3773+
)
3774+
weighted_argmax = (residue_weights[b] * per_alignment).argmax()
3775+
chain_wise_iptm[b, i, j] = per_alignment[weighted_argmax]
3776+
chain_wise_iptm_mask[b, i, j] = True
37763777

3777-
return chain_wise_iptm, chain_wise_iptm_mask, unique_chains
3778+
return chain_wise_iptm, chain_wise_iptm_mask, torch.tensor(unique_chains)
37783779

37793780
else:
37803781
pair_mask = torch.ones(size=(num_batch, num_res, num_res), device=logits.device).bool()
@@ -3974,7 +3975,8 @@ def compute_full_complex_metric(
39743975
atom_mask: Bool["b m"],
39753976
is_molecule_types: Bool[f"b n {IS_MOLECULE_TYPES}"],
39763977
return_confidence_score: bool = False,
3977-
) -> Float[" b"] | Tuple[Float[" b"], Tuple[ConfidenceScore, Bool[" b"]]]:
3978+
) -> Float[" b"] | Tuple[Float[" b"], Tuple[ConfidenceScore, Bool[" b"]]]:
3979+
39783980
"""Compute full complex metric.
39793981
39803982
:param confidence_head_logits: ConfidenceHeadLogits
@@ -4105,7 +4107,7 @@ def compute_interface_metric(
41054107

41064108
for b, chains in enumerate(interface_chains):
41074109
for chain in chains:
4108-
idx = unique_chains[b].index(chain)
4110+
idx = unique_chains[b].tolist().index(chain)
41094111
interface_metric[b] += iptm_sum[b, idx].sum() / iptm_count[b, idx].sum().clamp(
41104112
min=1
41114113
)

0 commit comments

Comments
 (0)