@@ -103,66 +103,37 @@ class MoleculeInput(TypedDict):
103103def molecule_to_atom_input (molecule_input : MoleculeInput ) -> AtomInput :
104104 raise NotImplementedError
105105
106- def validate_molecule_input (molecule_input : MoleculeInput ):
107- assert True
108-
109- # residue level - single chain proteins for starters
106+ # alphafold3 input - support polypeptides, nucleic acids, metal ions + any number of ligands + misc biomolecules
110107
111108@typecheck
112- class SingleProteinInput (TypedDict ):
113- residue_ids : Int ['n' ]
114- residue_atom_lens : Int ['n' ]
115- templates : Float ['t n n dt' ]
116- msa : Float ['s n dm' ]
117- template_mask : Bool ['t' ] | None
118- msa_mask : Bool ['s' ] | None
119- atom_pos : Float ['m 3' ] | None
120- distance_labels : Int ['n n' ] | None
121- pae_labels : Int ['n n' ] | None
122- pde_labels : Int ['n' ] | None
123- resolved_labels : Int ['n' ] | None
124-
125- @typecheck
126- def single_protein_input_to_atom_input (
127- input : SingleProteinInput
128- ) -> AtomInput :
129-
130- raise NotImplementedError
131-
132- # single chain protein with single ds nucleic acid
133-
134- # o - for nucleOtide seq
135-
136- @typecheck
137- class SingleProteinSingleNucleicAcidInput (TypedDict ):
138- residue_ids : Int ['n' ]
139- residue_atom_lens : Int ['n' ]
140- nucleotide_ids : Int ['o' ]
141- nucleic_acid_type : Literal ['dna' , 'rna' ]
109+ class Alphafold3Input (TypedDict ):
110+ proteins : List [Int ['_' ]]
111+ protein_atom_lens : List [Int ['_' ]]
112+ nucleic_acids : List [Int ['_' ]]
113+ nucleic_acid_atom_lens : List [Int ['_' ]]
114+ metal_ions : List [int ]
115+ misc_molecule_ids : List [int ]
116+ ligands : List [Mol | str ] # can be given as smiles
117+ atom_pos : List [Float ['_ 3' ]] | Float ['m 3' ] | None
142118 templates : Float ['t n n dt' ]
143119 msa : Float ['s n dm' ]
144120 template_mask : Bool ['t' ] | None
145121 msa_mask : Bool ['s' ] | None
146- atom_pos : Float ['m 3' ] | None
147122 distance_labels : Int ['n n' ] | None
148123 pae_labels : Int ['n n' ] | None
149124 pde_labels : Int ['n' ] | None
150125 resolved_labels : Int ['n' ] | None
151126
152127@typecheck
153- def single_protein_input_and_single_nucleic_acid_to_atom_input (
154- input : SingleProteinSingleNucleicAcidInput
155- ) -> AtomInput :
156-
128+ def af3_input_to_molecule_input (af3_input : Alphafold3Input ) -> AtomInput :
157129 raise NotImplementedError
158130
159131# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
160132# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
161133
162134INPUT_TO_ATOM_TRANSFORM = {
163135 MoleculeInput : molecule_to_atom_input ,
164- SingleProteinInput : single_protein_input_to_atom_input ,
165- SingleProteinSingleNucleicAcidInput : single_protein_input_and_single_nucleic_acid_to_atom_input
136+ Alphafold3Input : compose (af3_input_to_molecule_input , molecule_to_atom_input )
166137}
167138
168139# function for extending the config
0 commit comments