Skip to content

Commit 60c8bff

Browse files
committed
refactor into Alphafold3Input
1 parent 439a8ca commit 60c8bff

File tree

1 file changed

+12
-41
lines changed

1 file changed

+12
-41
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -103,66 +103,37 @@ class MoleculeInput(TypedDict):
103103
def molecule_to_atom_input(molecule_input: MoleculeInput) -> AtomInput:
104104
raise NotImplementedError
105105

106-
def validate_molecule_input(molecule_input: MoleculeInput):
107-
assert True
108-
109-
# residue level - single chain proteins for starters
106+
# alphafold3 input - support polypeptides, nucleic acids, metal ions + any number of ligands + misc biomolecules
110107

111108
@typecheck
112-
class SingleProteinInput(TypedDict):
113-
residue_ids: Int['n']
114-
residue_atom_lens: Int['n']
115-
templates: Float['t n n dt']
116-
msa: Float['s n dm']
117-
template_mask: Bool['t'] | None
118-
msa_mask: Bool['s'] | None
119-
atom_pos: Float['m 3'] | None
120-
distance_labels: Int['n n'] | None
121-
pae_labels: Int['n n'] | None
122-
pde_labels: Int['n'] | None
123-
resolved_labels: Int['n'] | None
124-
125-
@typecheck
126-
def single_protein_input_to_atom_input(
127-
input: SingleProteinInput
128-
) -> AtomInput:
129-
130-
raise NotImplementedError
131-
132-
# single chain protein with single ds nucleic acid
133-
134-
# o - for nucleOtide seq
135-
136-
@typecheck
137-
class SingleProteinSingleNucleicAcidInput(TypedDict):
138-
residue_ids: Int['n']
139-
residue_atom_lens: Int['n']
140-
nucleotide_ids: Int['o']
141-
nucleic_acid_type: Literal['dna', 'rna']
109+
class Alphafold3Input(TypedDict):
110+
proteins: List[Int['_']]
111+
protein_atom_lens: List[Int['_']]
112+
nucleic_acids: List[Int['_']]
113+
nucleic_acid_atom_lens: List[Int['_']]
114+
metal_ions: List[int]
115+
misc_molecule_ids: List[int]
116+
ligands: List[Mol | str] # can be given as smiles
117+
atom_pos: List[Float['_ 3']] | Float['m 3'] | None
142118
templates: Float['t n n dt']
143119
msa: Float['s n dm']
144120
template_mask: Bool['t'] | None
145121
msa_mask: Bool['s'] | None
146-
atom_pos: Float['m 3'] | None
147122
distance_labels: Int['n n'] | None
148123
pae_labels: Int['n n'] | None
149124
pde_labels: Int['n'] | None
150125
resolved_labels: Int['n'] | None
151126

152127
@typecheck
153-
def single_protein_input_and_single_nucleic_acid_to_atom_input(
154-
input: SingleProteinSingleNucleicAcidInput
155-
) -> AtomInput:
156-
128+
def af3_input_to_molecule_input(af3_input: Alphafold3Input) -> AtomInput:
157129
raise NotImplementedError
158130

159131
# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
160132
# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
161133

162134
INPUT_TO_ATOM_TRANSFORM = {
163135
MoleculeInput: molecule_to_atom_input,
164-
SingleProteinInput: single_protein_input_to_atom_input,
165-
SingleProteinSingleNucleicAcidInput: single_protein_input_and_single_nucleic_acid_to_atom_input
136+
Alphafold3Input: compose(af3_input_to_molecule_input, molecule_to_atom_input)
166137
}
167138

168139
# function for extending the config

0 commit comments

Comments
 (0)