Skip to content

Commit 63890bd

Browse files
committed
improve some types and typechecking in computing score
1 parent 49018a6 commit 63890bd

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3827,12 +3827,13 @@ def __init__(
38273827
self.chain_clash_count = chain_clash_count
38283828
self.chain_clash_ratio = chain_clash_ratio
38293829

3830+
@typecheck
38303831
def compute_has_clash(
38313832
self,
38323833
atom_pos: Float["m 3"],
38333834
asym_id: Int[" n"],
38343835
indices: Int[" m"],
3835-
valid_indices: Int[" m"],
3836+
valid_indices: Bool[" m"],
38363837
) -> Bool[""]:
38373838
"""Compute if there is a clash in the chain.
38383839
@@ -3872,13 +3873,15 @@ def compute_has_clash(
38723873

38733874
return torch.tensor(False, dtype=torch.bool, device=atom_pos.device)
38743875

3876+
@typecheck
38753877
def forward(
38763878
self,
38773879
atom_pos: Float["b m 3"] | Float["m 3"],
3878-
atom_mask: Bool["b m"] | Bool[" m"],
3880+
atom_mask: Bool["b m"] | Bool[" m"],
38793881
molecule_atom_lens: Int["b n"] | Int[" n"],
38803882
asym_id: Int["b n"] | Int[" n"],
3881-
) -> Bool[""]:
3883+
) -> Bool[" b"]:
3884+
38823885
"""Compute if there is a clash in the chain.
38833886
38843887
:param atom_pos: [b m 3] atom positions
@@ -3938,11 +3941,12 @@ def __init__(
39383941
self.score_ptm_weight = score_ptm_weight
39393942
self.score_disorder_weight = score_disorder_weight
39403943

3944+
@typecheck
39413945
def compute_disorder(
39423946
self,
39433947
plddt: Float["b m"],
3944-
atom_mask: Float["b m"],
3945-
atom_is_molecule_types: Float["b m"],
3948+
atom_mask: Bool["b m"],
3949+
atom_is_molecule_types: Bool[f"b m {IS_MOLECULE_TYPES}"],
39463950
) -> Float[" b"]:
39473951
"""Compute disorder score.
39483952
@@ -3959,6 +3963,7 @@ def compute_disorder(
39593963
disorder = ((atom_rasa > 0.581) * mask).sum(dim=-1) / (self.eps + mask.sum(dim=1))
39603964
return disorder
39613965

3966+
@typecheck
39623967
def compute_full_complex_metric(
39633968
self,
39643969
confidence_head_logits: ConfidenceHeadLogits,
@@ -3967,7 +3972,7 @@ def compute_full_complex_metric(
39673972
molecule_atom_lens: Int["b n"],
39683973
atom_pos: Float["b m 3"],
39693974
atom_mask: Bool["b m"],
3970-
is_molecule_types: Int[f"b n {IS_MOLECULE_TYPES}"],
3975+
is_molecule_types: Bool[f"b n {IS_MOLECULE_TYPES}"],
39713976
return_confidence_score: bool = False,
39723977
) -> Float[" b"] | Tuple[Float[" b"], Tuple[ConfidenceScore, Bool[" b"]]]:
39733978
"""Compute full complex metric.
@@ -4028,12 +4033,14 @@ def compute_full_complex_metric(
40284033

40294034
return weighted_score, (confidence_score, has_clash)
40304035

4036+
@typecheck
40314037
def compute_single_chain_metric(
40324038
self,
40334039
confidence_head_logits: ConfidenceHeadLogits,
40344040
asym_id: Int["b n"],
40354041
has_frame: Bool["b n"],
4036-
) -> Float[" b"]:
4042+
) -> Float[" b"]:
4043+
40374044
"""Compute single chain metric.
40384045
40394046
:param confidence_head_logits: ConfidenceHeadLogits
@@ -4051,6 +4058,7 @@ def compute_single_chain_metric(
40514058
score = confidence_score.ptm
40524059
return score
40534060

4061+
@typecheck
40544062
def compute_interface_metric(
40554063
self,
40564064
confidence_head_logits: ConfidenceHeadLogits,
@@ -4104,6 +4112,7 @@ def compute_interface_metric(
41044112
interface_metric[b] /= len(chains)
41054113
return interface_metric
41064114

4115+
@typecheck
41074116
def compute_modified_residue_score(
41084117
self,
41094118
confidence_head_logits: ConfidenceHeadLogits,
@@ -4132,7 +4141,6 @@ def compute_modified_residue_score(
41324141

41334142
# model selection
41344143

4135-
41364144
@typecheck
41374145
def get_cid_molecule_type(
41384146
cid: int,

tests/test_af3.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,9 +1036,6 @@ def test_alphafold3_config():
10361036
# test compute ranking score
10371037

10381038
def test_compute_ranking_score():
1039-
1040-
import random
1041-
import itertools
10421039

10431040
# mock inputs
10441041

0 commit comments

Comments
 (0)