@@ -3827,12 +3827,13 @@ def __init__(
38273827 self .chain_clash_count = chain_clash_count
38283828 self .chain_clash_ratio = chain_clash_ratio
38293829
3830+ @typecheck
38303831 def compute_has_clash (
38313832 self ,
38323833 atom_pos : Float ["m 3" ],
38333834 asym_id : Int [" n" ],
38343835 indices : Int [" m" ],
3835- valid_indices : Int [" m" ],
3836+ valid_indices : Bool [" m" ],
38363837 ) -> Bool ["" ]:
38373838 """Compute if there is a clash in the chain.
38383839
@@ -3872,13 +3873,15 @@ def compute_has_clash(
38723873
38733874 return torch .tensor (False , dtype = torch .bool , device = atom_pos .device )
38743875
3876+ @typecheck
38753877 def forward (
38763878 self ,
38773879 atom_pos : Float ["b m 3" ] | Float ["m 3" ],
3878- atom_mask : Bool ["b m" ] | Bool [" m" ],
3880+ atom_mask : Bool ["b m" ] | Bool [" m" ],
38793881 molecule_atom_lens : Int ["b n" ] | Int [" n" ],
38803882 asym_id : Int ["b n" ] | Int [" n" ],
3881- ) -> Bool ["" ]:
3883+ ) -> Bool [" b" ]:
3884+
38823885 """Compute if there is a clash in the chain.
38833886
38843887 :param atom_pos: [b m 3] atom positions
@@ -3938,11 +3941,12 @@ def __init__(
39383941 self .score_ptm_weight = score_ptm_weight
39393942 self .score_disorder_weight = score_disorder_weight
39403943
3944+ @typecheck
39413945 def compute_disorder (
39423946 self ,
39433947 plddt : Float ["b m" ],
3944- atom_mask : Float ["b m" ],
3945- atom_is_molecule_types : Float [ "b m" ],
3948+ atom_mask : Bool ["b m" ],
3949+ atom_is_molecule_types : Bool [ f "b m { IS_MOLECULE_TYPES } " ],
39463950 ) -> Float [" b" ]:
39473951 """Compute disorder score.
39483952
@@ -3959,6 +3963,7 @@ def compute_disorder(
39593963 disorder = ((atom_rasa > 0.581 ) * mask ).sum (dim = - 1 ) / (self .eps + mask .sum (dim = 1 ))
39603964 return disorder
39613965
3966+ @typecheck
39623967 def compute_full_complex_metric (
39633968 self ,
39643969 confidence_head_logits : ConfidenceHeadLogits ,
@@ -3967,7 +3972,7 @@ def compute_full_complex_metric(
39673972 molecule_atom_lens : Int ["b n" ],
39683973 atom_pos : Float ["b m 3" ],
39693974 atom_mask : Bool ["b m" ],
3970- is_molecule_types : Int [f"b n { IS_MOLECULE_TYPES } " ],
3975+ is_molecule_types : Bool [f"b n { IS_MOLECULE_TYPES } " ],
39713976 return_confidence_score : bool = False ,
39723977 ) -> Float [" b" ] | Tuple [Float [" b" ], Tuple [ConfidenceScore , Bool [" b" ]]]:
39733978 """Compute full complex metric.
@@ -4028,12 +4033,14 @@ def compute_full_complex_metric(
40284033
40294034 return weighted_score , (confidence_score , has_clash )
40304035
4036+ @typecheck
40314037 def compute_single_chain_metric (
40324038 self ,
40334039 confidence_head_logits : ConfidenceHeadLogits ,
40344040 asym_id : Int ["b n" ],
40354041 has_frame : Bool ["b n" ],
4036- ) -> Float [" b" ]:
4042+ ) -> Float [" b" ]:
4043+
40374044 """Compute single chain metric.
40384045
40394046 :param confidence_head_logits: ConfidenceHeadLogits
@@ -4051,6 +4058,7 @@ def compute_single_chain_metric(
40514058 score = confidence_score .ptm
40524059 return score
40534060
4061+ @typecheck
40544062 def compute_interface_metric (
40554063 self ,
40564064 confidence_head_logits : ConfidenceHeadLogits ,
@@ -4104,6 +4112,7 @@ def compute_interface_metric(
41044112 interface_metric [b ] /= len (chains )
41054113 return interface_metric
41064114
4115+ @typecheck
41074116 def compute_modified_residue_score (
41084117 self ,
41094118 confidence_head_logits : ConfidenceHeadLogits ,
@@ -4132,7 +4141,6 @@ def compute_modified_residue_score(
41324141
41334142# model selection
41344143
4135-
41364144@typecheck
41374145def get_cid_molecule_type (
41384146 cid : int ,
0 commit comments