Skip to content

Commit ebcac44

Browse files
committed
last cleanup for the day
1 parent 17ab525 commit ebcac44

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,13 +3309,20 @@ class ComputeRankingScore(Module):
33093309

33103310
def __init__(
33113311
self,
3312-
eps = 1e-8
3312+
eps = 1e-8,
3313+
score_iptm_weight = 0.8,
3314+
score_ptm_weight = 0.2,
3315+
score_disorder_weight = 0.5
33133316
):
33143317
super().__init__()
33153318
self.eps = eps
33163319
self.compute_clash = ComputeClash()
33173320
self.compute_confidence_score = ComputeConfidenceScore(eps=eps)
33183321

3322+
self.score_iptm_weight = score_iptm_weight
3323+
self.score_ptm_weight = score_ptm_weight
3324+
self.score_disorder_weight = score_disorder_weight
3325+
33193326
@typecheck
33203327
def compute_disorder(
33213328
self,
@@ -3342,7 +3349,8 @@ def compute_full_complex_metric(
33423349
atom_pos: Float['b m 3'],
33433350
atom_mask: Bool['b m'],
33443351
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'],
3345-
) -> Float[' b']:
3352+
return_confidence_score: bool = False
3353+
) -> Float[' b'] | Tuple[Float[' b'], Tuple[ConfidenceScore, Bool[' b']]]:
33463354

33473355
# Section 5.9.3.1
33483356

@@ -3372,9 +3380,18 @@ def compute_full_complex_metric(
33723380
disorder = self.compute_disorder(confidence_score.plddt, atom_mask, atom_is_molecule_types)
33733381

33743382
# Section 5.9.3 equation 19
3375-
score = 0.8 * confidence_score.iptm + 0.2 * confidence_score.ptm + 0.5 * disorder - 100 * has_clash
33763383

3377-
return score
3384+
weighted_score = (
3385+
confidence_score.iptm * self.score_iptm_weight +
3386+
confidence_score.ptm * self.score_ptm_weight +
3387+
disorder * self.score_disorder_weight
3388+
- 100 * has_clash
3389+
)
3390+
3391+
if not return_confidence_score:
3392+
return weighted_score
3393+
3394+
return weighted_score, (confidence_score, has_clash)
33783395

33793396
@typecheck
33803397
def compute_single_chain_metric(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.45"
3+
version = "0.2.46"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)