Skip to content

Commit 976aec4

Browse files
committed
fix types in inputs
1 parent 8baeb9a commit 976aec4

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

alphafold3_pytorch/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def cli(
3636
assert checkpoint_path.exists(), f'AlphaFold 3 checkpoint must exist at {str(checkpoint_path)}'
3737

3838
alphafold3_input = Alphafold3Input(
39-
proteins = protein,
40-
ss_rna = rna,
41-
ss_dna = dna,
39+
proteins = list(protein),
40+
ss_rna = list(rna),
41+
ss_dna = list(dna),
4242
)
4343

4444
alphafold3 = Alphafold3.init_and_load(checkpoint_path)

alphafold3_pytorch/inputs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ class BatchedAtomInput:
503503
resolution: Float[" b"] | None = None # type: ignore
504504
token_constraints: Float["b n n dac"] | None = None # type: ignore
505505
chains: Int["b 2"] | None = None # type: ignore
506-
filepath: List[str] | None = None
506+
filepath: List[str] | Tuple[str, ...] | None = None
507507

508508
def dict(self):
509509
"""Return the dataclass as a dictionary."""
@@ -736,11 +736,11 @@ class MoleculeInput:
736736
molecule_ids: Int[" n"] # type: ignore
737737
additional_molecule_feats: Int[f"n {ADDITIONAL_MOLECULE_FEATS}"] # type: ignore
738738
is_molecule_types: Bool[f"n {IS_MOLECULE_TYPES}"] # type: ignore
739-
src_tgt_atom_indices: Int["n 2"] # type: ignore
739+
src_tgt_atom_indices: Int["n 2"] | List[List[int]] # type: ignore
740740
token_bonds: Bool["n n"] # type: ignore
741741
is_molecule_mod: Bool["n num_mods"] | Bool[" n"] | None = None # type: ignore
742-
molecule_atom_indices: List[int | None] | None = None # type: ignore
743-
distogram_atom_indices: List[int | None] | None = None # type: ignore
742+
molecule_atom_indices: List[int | None] | Int[" n"] | None = None # type: ignore
743+
distogram_atom_indices: List[int | None] | Int[" n"] | None = None # type: ignore
744744
atom_indices_for_frame: Int["n 3"] | None = None # type: ignore
745745
missing_atom_indices: List[Int[" _"] | None] | None = None # type: ignore
746746
missing_token_indices: List[Int[" _"] | None] | None = None # type: ignore
@@ -1085,8 +1085,8 @@ class MoleculeLengthMoleculeInput:
10851085
token_bonds: Bool["n n"] | None = None # type: ignore
10861086
one_token_per_atom: List[bool] | None = None
10871087
is_molecule_mod: Bool["n num_mods"] | Bool[" n"] | None = None # type: ignore
1088-
molecule_atom_indices: List[int | None] | None = None
1089-
distogram_atom_indices: List[int | None] | None = None
1088+
molecule_atom_indices: List[int | None] | Int[" n"] | None = None
1089+
distogram_atom_indices: List[int | None] | Int[" n"] | None = None
10901090
atom_indices_for_frame: List[Tuple[int, int, int] | None] | None = None
10911091
missing_atom_indices: List[Int[" _"] | None] | None = None # type: ignore
10921092
missing_token_indices: List[Int[" _"] | None] | None = None # type: ignore
@@ -2178,8 +2178,8 @@ class PDBInput:
21782178
directed_bonds: bool = False
21792179
custom_atoms: List[str] | None = None
21802180
custom_bonds: List[str] | None = None
2181-
training: bool = False
2182-
inference: bool = False
2181+
training: bool | None = None
2182+
inference: bool | None = None
21832183
distillation: bool = False
21842184
distillation_multimer_sampling_ratio: float = 2.0 / 3.0
21852185
distillation_pdb_ids: List[str] | None = None

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.7.8"
3+
version = "0.7.9"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },
@@ -47,7 +47,7 @@ dependencies = [
4747
"huggingface_hub>=0.21.4",
4848
"jaxtyping>=0.2.28",
4949
"lightning>=2.2.5",
50-
"multimolecule",
50+
"multimolecule==0.0.5",
5151
"nimporter",
5252
"numpy>=1.23.5",
5353
"polars>=1.1.0",

0 commit comments

Comments
 (0)