11from functools import wraps
2- from typing import Type , TypedDict , Literal , Callable , List
2+ from dataclasses import dataclass
3+ from typing import Type , Literal , Callable , List , Any
34
45from rdkit import Chem
56from rdkit .Chem .rdchem import Mol
67
78from alphafold3_pytorch .tensor_typing import (
89 typecheck ,
10+ beartype_isinstance ,
911 Int , Bool , Float
1012)
1113
# size of the last dimension of `is_molecule_types` (see AtomInput / BatchedAtomInput)
IS_MOLECULE_TYPES = 4

# size of the last dimension of `additional_molecule_feats` (see AtomInput / BatchedAtomInput)
ADDITIONAL_MOLECULE_FEATS = 5
2325
24- # simple compose function
25- # for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
26+ # functions
27+
def exists(v):
    """Return True when `v` is anything other than None (falsy values such as 0 or '' still count)."""
    return v is not None
30+
def identity(t):
    """Return `t` unchanged; used as the no-op transform for inputs that need no conversion."""
    return t
2633
2734def compose (* fns : Callable ):
35+ # for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
36+
2837 def inner (x , * args , ** kwargs ):
2938 for fn in fns :
3039 x = fn (x , * args , ** kwargs )
@@ -34,30 +43,33 @@ def inner(x, *args, **kwargs):
3443# atom level, what Alphafold3 accepts
3544
@typecheck
@dataclass
class AtomInput:
    # atom level input, stored as a dataclass
    # the bracketed strings are jaxtyping-style shape annotations - presumably enforced by `typecheck` (confirm in tensor_typing)

    # required fields - these carry no default and must always be supplied

    atom_inputs:                Float['m dai']
    molecule_ids:               Int[' n']
    molecule_atom_lens:         Int[' n']
    atompair_inputs:            Float['m m dapi'] | Float['nw w (w*2) dapi']
    additional_molecule_feats:  Float[f'n {ADDITIONAL_MOLECULE_FEATS}']
    is_molecule_types:          Bool[f'n {IS_MOLECULE_TYPES}']
    templates:                  Float['t n n dt']
    msa:                        Float['s n dm']

    # optional fields - all default to None
    # dataclasses require every defaulted field to come after the non-defaulted ones

    token_bonds:                Bool['n n'] | None = None
    atom_ids:                   Int[' m'] | None = None
    atom_parent_ids:            Int[' m'] | None = None
    atompair_ids:               Int['m m'] | Int['nw w (w*2)'] | None = None
    template_mask:              Bool[' t'] | None = None
    msa_mask:                   Bool[' s'] | None = None
    atom_pos:                   Float['m 3'] | None = None
    molecule_atom_indices:      Int[' n'] | None = None
    distance_labels:            Int['n n'] | None = None
    pae_labels:                 Int['n n'] | None = None
    pde_labels:                 Int['n n'] | None = None
    plddt_labels:               Int[' n'] | None = None
    resolved_labels:            Int[' n'] | None = None
5869
5970@typecheck
60- class BatchedAtomInput (TypedDict ):
71+ @dataclass
72+ class BatchedAtomInput :
6173 atom_inputs : Float ['b m dai' ]
6274 molecule_ids : Int ['b n' ]
6375 molecule_atom_lens : Int ['b n' ]
@@ -66,38 +78,40 @@ class BatchedAtomInput(TypedDict):
6678 is_molecule_types : Bool [f'b n { IS_MOLECULE_TYPES } ' ]
6779 templates : Float ['b t n n dt' ]
6880 msa : Float ['b s n dm' ]
69- token_bonds : Bool ['b n n' ] | None
70- atom_ids : Int ['b m' ] | None
71- atom_parent_ids : Int ['b m' ] | None
72- atompair_ids : Int ['b m m' ] | Int ['b nw w (w*2)' ] | None
73- template_mask : Bool ['b t' ] | None
74- msa_mask : Bool ['b s' ] | None
75- atom_pos : Float ['b m 3' ] | None
76- molecule_atom_indices : Int ['b n' ] | None
77- distance_labels : Int ['b n n' ] | None
78- pae_labels : Int ['b n n' ] | None
79- pde_labels : Int ['b n' ] | None
80- resolved_labels : Int ['b n' ] | None
81+ token_bonds : Bool ['b n n' ] | None = None
82+ atom_ids : Int ['b m' ] | None = None
83+ atom_parent_ids : Int ['b m' ] | None = None
84+ atompair_ids : Int ['b m m' ] | Int ['b nw w (w*2)' ] | None = None
85+ template_mask : Bool ['b t' ] | None = None
86+ msa_mask : Bool ['b s' ] | None = None
87+ atom_pos : Float ['b m 3' ] | None = None
88+ molecule_atom_indices : Int ['b n' ] | None = None
89+ distance_labels : Int ['b n n' ] | None = None
90+ pae_labels : Int ['b n n' ] | None = None
91+ pde_labels : Int ['b n n' ] | None = None
92+ plddt_labels : Int ['b n' ] | None = None
93+ resolved_labels : Int ['b n' ] | None = None
8194
8295# molecule input - accepting list of molecules as rdchem.Mol + the atomic lengths for how to pool into tokens
8396
@typecheck
@dataclass
class MoleculeInput:
    # molecule level input - a list of rdkit molecules plus the lengths used to pool atoms into tokens
    #
    # field order matters for a dataclass: a field without a default may not follow one with a default
    # (the class body raises TypeError otherwise), so all required fields are declared first.
    # prefer keyword arguments when constructing, as positional order follows declaration order

    # required fields

    molecules:                  List[Mol]
    molecule_token_pool_lens:   List[List[int]]
    molecule_atom_indices:      List[List[int] | None]
    molecule_ids:               Int[' n']
    additional_molecule_feats:  Float['n 5']
    is_molecule_types:          Bool['n 4']
    templates:                  Float['t n n dt']
    msa:                        Float['s n dm']

    # optional fields - all default to None

    atom_pos:                   List[Float['_ 3']] | Float['m 3'] | None = None
    template_mask:              Bool[' t'] | None = None
    msa_mask:                   Bool[' s'] | None = None
    distance_labels:            Int['n n'] | None = None
    pae_labels:                 Int['n n'] | None = None
    pde_labels:                 Int['n n'] | None = None  # shape kept in line with AtomInput.pde_labels
    resolved_labels:            Int[' n'] | None = None

    # NOTE(review): AtomInput also declares `plddt_labels`, which has no counterpart here - confirm whether one is needed
101115
102116@typecheck
103117def molecule_to_atom_input (molecule_input : MoleculeInput ) -> AtomInput :
@@ -106,41 +120,79 @@ def molecule_to_atom_input(molecule_input: MoleculeInput) -> AtomInput:
106120# alphafold3 input - support polypeptides, nucleic acids, metal ions + any number of ligands + misc biomolecules
107121
@typecheck
@dataclass
class Alphafold3Input:
    # highest level input - biomolecules are given per category and ligands may be passed as smiles strings
    #
    # field order matters for a dataclass: a field without a default may not follow one with a default
    # (the class body raises TypeError otherwise), so all required fields are declared first.
    # prefer keyword arguments when constructing, as positional order follows declaration order

    # required fields

    proteins:                   List[Int[' _']]
    protein_atom_lens:          List[Int[' _']]
    nucleic_acids:              List[Int[' _']]
    nucleic_acid_atom_lens:     List[Int[' _']]
    metal_ions:                 Int[' _']
    misc_molecule_ids:          Int[' _']
    ligands:                    List[Mol | str]  # can be given as smiles
    templates:                  Float['t n n dt']
    msa:                        Float['s n dm']

    # optional fields - all default to None

    atom_pos:                   List[Float['_ 3']] | Float['m 3'] | None = None
    template_mask:              Bool[' t'] | None = None
    msa_mask:                   Bool[' s'] | None = None
    distance_labels:            Int['n n'] | None = None
    pae_labels:                 Int['n n'] | None = None
    pde_labels:                 Int['n n'] | None = None  # shape kept in line with AtomInput.pde_labels
    resolved_labels:            Int[' n'] | None = None
126141
@typecheck
def af3_input_to_molecule_input(af3_input: Alphafold3Input) -> MoleculeInput:
    """Convert an Alphafold3Input into a MoleculeInput.

    The result is what `molecule_to_atom_input` consumes when the two are chained
    in INPUT_TO_ATOM_TRANSFORM, hence the MoleculeInput return annotation.

    Raises:
        NotImplementedError: always - the conversion has not been written yet.
    """
    raise NotImplementedError
130145
146+ # pdb input
147+
@typecheck
@dataclass
class PDBInput:
    # input that refers to a file by path, stored as a plain string
    # presumably a path to a PDB file given the class name - no code here reads it yet

    filepath: str
152+
@typecheck
def pdb_input_to_alphafold3_input(pdb_input: PDBInput) -> Alphafold3Input:
    """Convert a PDBInput into an Alphafold3Input.

    Raises:
        NotImplementedError: always - the conversion has not been written yet.
    """
    raise NotImplementedError
156+
131157# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
132158# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
133159
# maps each supported input class to the callable that turns an instance of it into an AtomInput
# richer input types chain through the intermediate representations using `compose`

INPUT_TO_ATOM_TRANSFORM = {
    # already atom level, nothing to do
    AtomInput: identity,

    MoleculeInput: molecule_to_atom_input,

    Alphafold3Input: compose(
        af3_input_to_molecule_input,
        molecule_to_atom_input
    ),

    PDBInput: compose(
        pdb_input_to_alphafold3_input,
        af3_input_to_molecule_input,
        molecule_to_atom_input
    )
}
138173
139174# function for extending the config
140175
@typecheck
def register_input_transform(
    input_type: Type,
    fn: Callable[[Any], AtomInput]
):
    """Register `fn` as the transform that converts `input_type` instances to AtomInput.

    Args:
        input_type: the class to use as the key in INPUT_TO_ATOM_TRANSFORM.
        fn: callable taking an instance of `input_type` and returning an AtomInput.

    Raises:
        AssertionError: if `input_type` already has a registered transform.
    """
    # raised explicitly rather than through `assert`, so the guard still runs under `python -O`
    # the exception type is kept as AssertionError so existing callers see no difference
    if input_type in INPUT_TO_ATOM_TRANSFORM:
        raise AssertionError(f'{input_type} is already registered')

    INPUT_TO_ATOM_TRANSFORM[input_type] = fn
183+
184+ # functions for transforming to atom inputs
185+
def maybe_transform_to_atom_inputs(inputs: List[Any]) -> List[AtomInput]:
    """Convert every element of `inputs` with the transform registered for its type.

    The transform is looked up in INPUT_TO_ATOM_TRANSFORM by walking the element's
    method resolution order. An exact type match always wins, and instances of a
    subclass of a registered type fall back to that parent's transform.

    Args:
        inputs: objects whose type (or one of its base classes) is registered.

    Returns:
        A list with the result of each transform, in the same order as `inputs`.

    Raises:
        TypeError: if no class in an element's MRO has a registered transform.
    """
    atom_inputs = []

    for item in inputs:
        item_type = type(item)

        # __mro__ starts with the exact type, so a direct registration is found first
        maybe_to_atom_fn = next(
            (
                INPUT_TO_ATOM_TRANSFORM[klass]
                for klass in item_type.__mro__
                if klass in INPUT_TO_ATOM_TRANSFORM
            ),
            None
        )

        if not exists(maybe_to_atom_fn):
            raise TypeError(f'invalid input type {item_type} being passed into Trainer that is not converted to AtomInput correctly')

        atom_inputs.append(maybe_to_atom_fn(item))

    return atom_inputs
0 commit comments