Skip to content

Commit 8f22c88

Browse files
committed
add some typecheck decorators, linting
1 parent 6b25e7e commit 8f22c88

File tree

4 files changed

+34
-12
lines changed

4 files changed

+34
-12
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,7 @@ def __init__(
11451145
def forward(
11461146
self,
11471147
*,
1148-
additional_molecule_feats: Float[f'b n {ADDITIONAL_molecule_FEATS}']
1148+
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}']
11491149
) -> Float['b n n dp']:
11501150

11511151
device = additional_molecule_feats.device

alphafold3_pytorch/pdb_dataset_curation.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@
2929
#
3030

3131
# %%
32+
from __future__ import annotations
3233
import argparse
3334
import glob
3435
import os
3536
import random
36-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
37+
from typing import Any, Dict, List, Set, Tuple
3738

3839
import pandas as pd
3940
from Bio.PDB import MMCIFIO, PDBIO, MMCIFParser, PDBParser
@@ -45,6 +46,8 @@
4546
from pdbeccdutils.core.ccd_reader import CCDReaderResult
4647
from tqdm.contrib.concurrent import process_map
4748

49+
from alphafold3_pytorch.typing import typecheck
50+
4851
# Parse command-line arguments
4952

5053
parser = argparse.ArgumentParser(
@@ -99,7 +102,7 @@
99102

100103
# Constants
101104

102-
Token = Union[Residue, Atom]
105+
Token = Residue | Atom
103106

104107
# Section 2.5.4 of the AlphaFold 3 supplement
105108

@@ -184,6 +187,7 @@ def exists(v: Any) -> bool:
184187
return v is not None
185188

186189

190+
@typecheck
187191
def parse_structure(filepath: str) -> Structure:
188192
"""Parse a structure from a PDB or mmCIF file."""
189193
if filepath.endswith(".pdb"):
@@ -196,7 +200,7 @@ def parse_structure(filepath: str) -> Structure:
196200
structure = parser.get_structure(structure_id, filepath)
197201
return structure
198202

199-
203+
@typecheck
200204
def filter_pdb_deposition_date(
201205
structure: Structure, cutoff_date: pd.Timestamp = pd.to_datetime("2021-09-30")
202206
) -> bool:
@@ -210,17 +214,17 @@ def filter_pdb_deposition_date(
210214
return False
211215

212216

217+
@typecheck
213218
def filter_resolution(structure: Structure, max_resolution: float = 9.0) -> bool:
214219
"""Filter based on resolution."""
215-
if (
220+
return (
216221
"resolution" in structure.header
217222
and exists(structure.header["resolution"])
218223
and structure.header["resolution"] <= max_resolution
219-
):
220-
return True
221-
return False
224+
)
222225

223226

227+
@typecheck
224228
def filter_polymer_chains(
225229
structure: Structure, max_chains: int = 1000, for_training: bool = False
226230
) -> bool:
@@ -229,6 +233,7 @@ def filter_polymer_chains(
229233
return count <= (300 if for_training else max_chains)
230234

231235

236+
@typecheck
232237
def filter_resolved_chains(structure: Structure) -> Structure:
233238
"""Filter based on number of resolved residues."""
234239
chains_to_remove = [
@@ -243,7 +248,8 @@ def filter_resolved_chains(structure: Structure) -> Structure:
243248
return structure if list(structure.get_chains()) else None
244249

245250

246-
def filter_target(structure: Structure) -> Optional[Structure]:
251+
@typecheck
252+
def filter_target(structure: Structure) -> Structure | None:
247253
"""Filter a target based on various criteria."""
248254
target_passes_prefilters = (
249255
filter_pdb_deposition_date(structure)
@@ -253,6 +259,7 @@ def filter_target(structure: Structure) -> Optional[Structure]:
253259
return filter_resolved_chains(structure) if target_passes_prefilters else None
254260

255261

262+
@typecheck
256263
def remove_hydrogens(structure: Structure, remove_waters: bool = True) -> Structure:
257264
"""
258265
Remove hydrogens (and optionally waters) from a structure.
@@ -284,6 +291,7 @@ def remove_hydrogens(structure: Structure, remove_waters: bool = True) -> Struct
284291
return structure
285292

286293

294+
@typecheck
287295
def remove_all_unknown_residue_chains(
288296
structure: Structure, standard_residues: Set[str]
289297
) -> Structure:
@@ -300,6 +308,7 @@ def remove_all_unknown_residue_chains(
300308
return structure
301309

302310

311+
@typecheck
303312
def remove_clashing_chains(
304313
structure: Structure, clash_threshold: float = 1.7, clash_percentage: float = 0.3
305314
) -> Structure:
@@ -360,6 +369,7 @@ def remove_clashing_chains(
360369
return structure
361370

362371

372+
@typecheck
363373
def remove_excluded_ligands(structure: Structure, ligand_exclusion_list: Set[str]) -> Structure:
364374
"""
365375
Remove ligands in the exclusion list.
@@ -385,6 +395,7 @@ def remove_excluded_ligands(structure: Structure, ligand_exclusion_list: Set[str
385395
return structure
386396

387397

398+
@typecheck
388399
def remove_non_ccd_atoms(
389400
structure: Structure, ccd_reader_results: Dict[str, CCDReaderResult]
390401
) -> Structure:
@@ -417,6 +428,7 @@ def remove_non_ccd_atoms(
417428
return structure
418429

419430

431+
@typecheck
420432
def is_covalently_bonded(atom1: Atom, atom2: Atom) -> bool:
421433
"""
422434
Check if two atoms are covalently bonded.
@@ -431,6 +443,7 @@ def is_covalently_bonded(atom1: Atom, atom2: Atom) -> bool:
431443
return False
432444

433445

446+
@typecheck
434447
def remove_leaving_atoms(
435448
structure: Structure, ccd_reader_results: Dict[str, CCDReaderResult]
436449
) -> Structure:
@@ -510,6 +523,7 @@ def remove_leaving_atoms(
510523
return structure
511524

512525

526+
@typecheck
513527
def filter_large_ca_distances(structure: Structure) -> Structure:
514528
"""
515529
Filter chains with large Ca atom distances.
@@ -533,6 +547,7 @@ def filter_large_ca_distances(structure: Structure) -> Structure:
533547
return structure
534548

535549

550+
@typecheck
536551
def select_closest_chains(
537552
structure: Structure,
538553
protein_residue_center_atoms: Dict[str, str],
@@ -541,6 +556,7 @@ def select_closest_chains(
541556
) -> Structure:
542557
"""Select the closest chains in large bioassemblies."""
543558

559+
@typecheck
544560
def get_tokens_from_residues(
545561
residues: List[Residue],
546562
protein_residue_center_atoms: Dict[str, str],
@@ -559,6 +575,7 @@ def get_tokens_from_residues(
559575
tokens.append(atom)
560576
return tokens
561577

578+
@typecheck
562579
def get_token_center_atom(
563580
token: Token,
564581
protein_residue_center_atoms: Dict[str, str],
@@ -574,6 +591,7 @@ def get_token_center_atom(
574591
token_center_atom = token
575592
return token_center_atom
576593

594+
@typecheck
577595
def get_token_center_atoms(
578596
tokens: List[Token],
579597
protein_residue_center_atoms: Dict[str, str],
@@ -588,6 +606,7 @@ def get_token_center_atoms(
588606
token_center_atoms.append(token_center_atom)
589607
return token_center_atoms
590608

609+
@typecheck
591610
def get_interface_tokens(
592611
tokens: List[Token],
593612
protein_residue_center_atoms: Dict[str, str],
@@ -645,6 +664,7 @@ def get_interface_tokens(
645664
return structure
646665

647666

667+
@typecheck
648668
def remove_crystallization_aids(
649669
structure: Structure, crystallography_methods: Dict[str, Set[str]]
650670
) -> Structure:
@@ -672,6 +692,7 @@ def remove_crystallization_aids(
672692
return structure
673693

674694

695+
@typecheck
675696
def write_structure(structure: Structure, output_filepath: str):
676697
"""Write a structure to a PDB or mmCIF file."""
677698
if output_filepath.endswith(".pdb"):
@@ -684,6 +705,7 @@ def write_structure(structure: Structure, output_filepath: str):
684705
io.save(output_filepath)
685706

686707

708+
@typecheck
687709
def process_structure(args: Tuple[str, str, bool]):
688710
"""
689711
Given an input mmCIF file, create a new processed mmCIF file

alphafold3_pytorch/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
from functools import wraps
3+
from functools import wraps, partial
44
from pathlib import Path
55

66
from alphafold3_pytorch.alphafold3 import Alphafold3
77
from alphafold3_pytorch.attention import pad_at_dim
88

9-
from typing import TypedDict, List
9+
from typing import TypedDict, List, Callable
1010

1111
from alphafold3_pytorch.typing import (
1212
typecheck,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.12"
3+
version = "0.1.14"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)