Skip to content

Commit 2694184

Browse files
committed
fix types
1 parent c461701 commit 2694184

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3492,12 +3492,14 @@ def compute_modified_residue_score(
34923492
return plddt_mean
34933493

34943494
# model selection
3495+
3496+
@typecheck
34953497
def get_cid_molecule_type(
34963498
cid: int,
34973499
asym_id: Int[' n'],
3498-
is_molecule_types: Bool['n {IS_MOLECULE_TYPES}'],
3500+
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}'],
34993501
return_one_hot: bool = False,
3500-
) -> int | Bool[' {IS_MOLECULE_TYPES}']:
3502+
) -> int | Bool[f' {IS_MOLECULE_TYPES}']:
35013503
"""
35023504
35033505
get the molecule type for where asym_id == cid
@@ -3517,6 +3519,7 @@ def get_cid_molecule_type(
35173519

35183520
class ComputeModelSelectionScore(Module):
35193521

3522+
@typecheck
35203523
def __init__(
35213524
self,
35223525
eps: float = 1e-8,
@@ -3535,6 +3538,7 @@ def __init__(
35353538

35363539
self.register_buffer('dist_breaks', dist_breaks)
35373540

3541+
@typecheck
35383542
def compute_gpde(
35393543
self,
35403544
pde_logits: Float['b pde n n'],
@@ -3570,6 +3574,7 @@ def compute_gpde(
35703574

35713575
return gpde
35723576

3577+
@typecheck
35733578
def compute_lddt(
35743579
self,
35753580
pred_coords: Float['b m 3'],
@@ -3628,13 +3633,14 @@ def compute_lddt(
36283633

36293634
return lddt_mean
36303635

3636+
@typecheck
36313637
def compute_chain_pair_lddt(
36323638
self,
36333639
asym_mask_a: Bool['b m'] | Bool [' m'],
36343640
asym_mask_b: Bool['b m'] | Bool [' m'],
36353641
pred_coords: Float['b m 3'] | Float['m 3'],
36363642
true_coords: Float['b m 3'] | Float['m 3'],
3637-
is_molecule_types: Int['b m {IS_MOLECULE_TYPES}'] | Int['m {IS_MOLECULE_TYPES}'],
3643+
is_molecule_types: Bool[f'b m {IS_MOLECULE_TYPES}'] | Bool[f'm {IS_MOLECULE_TYPES}'],
36383644
coords_mask: Bool['b m'] | Bool [' m'] | None = None,
36393645
) -> Float[' b']:
36403646
"""
@@ -3665,6 +3671,7 @@ def compute_chain_pair_lddt(
36653671

36663672
return lddt
36673673

3674+
@typecheck
36683675
def get_lddt_weight(
36693676
self,
36703677
type_chain_a,
@@ -3728,6 +3735,7 @@ def get_lddt_weight(
37283735
assert weight, f"Weight not found for {interface_type} {lddt_type}"
37293736
return weight
37303737

3738+
@typecheck
37313739
def compute_weighted_lddt(
37323740
self,
37333741
# atom level input
@@ -3736,7 +3744,7 @@ def compute_weighted_lddt(
37363744
atom_mask: Bool['b m'] | None,
37373745
# token level input
37383746
asym_id: Int['b n'],
3739-
is_molecule_types: Bool['b n {IS_MOLECULE_TYPES}'],
3747+
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'],
37403748
molecule_atom_lens: Int['b n'],
37413749
# additional input
37423750
chains_list: List[Tuple[int, int] | Tuple[int]],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.51"
3+
version = "0.2.52"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)