Skip to content

Commit 3bd8bf4

Browse files
committed
switch to using dataclasses for the inputs
1 parent 60c8bff commit 3bd8bf4

File tree

5 files changed

+134
-92
lines changed

5 files changed

+134
-92
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
full_pairwise_repr_to_windowed
3939
)
4040

41+
from alphafold3_pytorch.inputs import (
42+
IS_MOLECULE_TYPES,
43+
ADDITIONAL_MOLECULE_FEATS
44+
)
45+
4146
from frame_averaging_pytorch import FrameAverage
4247

4348
from taylor_series_linear_attention import TaylorSeriesLinearAttn
@@ -98,11 +103,6 @@
98103

99104
# constants
100105

101-
from alphafold3_pytorch.inputs import (
102-
IS_MOLECULE_TYPES,
103-
ADDITIONAL_MOLECULE_FEATS
104-
)
105-
106106
LinearNoBias = partial(Linear, bias = False)
107107

108108
# helper functions

alphafold3_pytorch/inputs.py

Lines changed: 108 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from functools import wraps
2-
from typing import Type, TypedDict, Literal, Callable, List
2+
from dataclasses import dataclass
3+
from typing import Type, Literal, Callable, List, Any
34

45
from rdkit import Chem
56
from rdkit.Chem.rdchem import Mol
67

78
from alphafold3_pytorch.tensor_typing import (
89
typecheck,
10+
beartype_isinstance,
911
Int, Bool, Float
1012
)
1113

@@ -21,10 +23,17 @@
2123
IS_MOLECULE_TYPES = 4
2224
ADDITIONAL_MOLECULE_FEATS = 5
2325

24-
# simple compose function
25-
# for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
26+
# functions
27+
28+
def exists(v):
29+
return v is not None
30+
31+
def identity(t):
32+
return t
2633

2734
def compose(*fns: Callable):
35+
# for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
36+
2837
def inner(x, *args, **kwargs):
2938
for fn in fns:
3039
x = fn(x, *args, **kwargs)
@@ -34,30 +43,33 @@ def inner(x, *args, **kwargs):
3443
# atom level, what Alphafold3 accepts
3544

3645
@typecheck
37-
class AtomInput(TypedDict):
46+
@dataclass
47+
class AtomInput:
3848
atom_inputs: Float['m dai']
39-
molecule_ids: Int['n']
40-
molecule_atom_lens: Int['n']
49+
molecule_ids: Int[' n']
50+
molecule_atom_lens: Int[' n']
4151
atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
4252
additional_molecule_feats: Float[f'n {ADDITIONAL_MOLECULE_FEATS}']
4353
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
4454
templates: Float['t n n dt']
4555
msa: Float['s n dm']
46-
token_bonds: Bool['n n'] | None
47-
atom_ids: Int['m'] | None
48-
atom_parent_ids: Int['m'] | None
49-
atompair_ids: Int['m m'] | Int['nw w (w*2)'] | None
50-
template_mask: Bool['t'] | None
51-
msa_mask: Bool['s'] | None
52-
atom_pos: Float['m 3'] | None
53-
molecule_atom_indices: Int['n'] | None
54-
distance_labels: Int['n n'] | None
55-
pae_labels: Int['n n'] | None
56-
pde_labels: Int['n'] | None
57-
resolved_labels: Int['n'] | None
56+
token_bonds: Bool['n n'] | None = None
57+
atom_ids: Int[' m'] | None = None
58+
atom_parent_ids: Int[' m'] | None = None
59+
atompair_ids: Int['m m'] | Int['nw w (w*2)'] | None = None
60+
template_mask: Bool[' t'] | None = None
61+
msa_mask: Bool[' s'] | None = None
62+
atom_pos: Float['m 3'] | None = None
63+
molecule_atom_indices: Int[' n'] | None = None
64+
distance_labels: Int['n n'] | None = None
65+
pae_labels: Int['n n'] | None = None
66+
pde_labels: Int['n n'] | None = None
67+
plddt_labels: Int[' n'] | None = None
68+
resolved_labels: Int[' n'] | None = None
5869

5970
@typecheck
60-
class BatchedAtomInput(TypedDict):
71+
@dataclass
72+
class BatchedAtomInput:
6173
atom_inputs: Float['b m dai']
6274
molecule_ids: Int['b n']
6375
molecule_atom_lens: Int['b n']
@@ -66,38 +78,40 @@ class BatchedAtomInput(TypedDict):
6678
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}']
6779
templates: Float['b t n n dt']
6880
msa: Float['b s n dm']
69-
token_bonds: Bool['b n n'] | None
70-
atom_ids: Int['b m'] | None
71-
atom_parent_ids: Int['b m'] | None
72-
atompair_ids: Int['b m m'] | Int['b nw w (w*2)'] | None
73-
template_mask: Bool['b t'] | None
74-
msa_mask: Bool['b s'] | None
75-
atom_pos: Float['b m 3'] | None
76-
molecule_atom_indices: Int['b n'] | None
77-
distance_labels: Int['b n n'] | None
78-
pae_labels: Int['b n n'] | None
79-
pde_labels: Int['b n'] | None
80-
resolved_labels: Int['b n'] | None
81+
token_bonds: Bool['b n n'] | None = None
82+
atom_ids: Int['b m'] | None = None
83+
atom_parent_ids: Int['b m'] | None = None
84+
atompair_ids: Int['b m m'] | Int['b nw w (w*2)'] | None = None
85+
template_mask: Bool['b t'] | None = None
86+
msa_mask: Bool['b s'] | None = None
87+
atom_pos: Float['b m 3'] | None = None
88+
molecule_atom_indices: Int['b n'] | None = None
89+
distance_labels: Int['b n n'] | None = None
90+
pae_labels: Int['b n n'] | None = None
91+
pde_labels: Int['b n n'] | None = None
92+
plddt_labels: Int['b n'] | None = None
93+
resolved_labels: Int['b n'] | None = None
8194

8295
# molecule input - accepting list of molecules as rdchem.Mol + the atomic lengths for how to pool into tokens
8396

8497
@typecheck
85-
class MoleculeInput(TypedDict):
98+
@dataclass
99+
class MoleculeInput:
86100
molecules: List[Mol]
87101
molecule_token_pool_lens: List[List[int]]
88102
molecule_atom_indices: List[List[int] | None]
89-
molecule_ids: Int['n']
103+
molecule_ids: Int[' n']
90104
additional_molecule_feats: Float['n 5']
91105
is_molecule_types: Bool['n 4']
92-
atom_pos: List[Float['_ 3']] | Float['m 3'] | None
106+
atom_pos: List[Float['_ 3']] | Float['m 3'] | None = None
93107
templates: Float['t n n dt']
94-
template_mask: Bool['t'] | None
108+
template_mask: Bool[' t'] | None = None
95109
msa: Float['s n dm']
96-
msa_mask: Bool['s'] | None
97-
distance_labels: Int['n n'] | None
98-
pae_labels: Int['n n'] | None
99-
pde_labels: Int['n'] | None
100-
resolved_labels: Int['n'] | None
110+
msa_mask: Bool[' s'] | None = None
111+
distance_labels: Int['n n'] | None = None
112+
pae_labels: Int['n n'] | None = None
113+
pde_labels: Int[' n'] | None = None
114+
resolved_labels: Int[' n'] | None = None
101115

102116
@typecheck
103117
def molecule_to_atom_input(molecule_input: MoleculeInput) -> AtomInput:
@@ -106,41 +120,79 @@ def molecule_to_atom_input(molecule_input: MoleculeInput) -> AtomInput:
106120
# alphafold3 input - support polypeptides, nucleic acids, metal ions + any number of ligands + misc biomolecules
107121

108122
@typecheck
109-
class Alphafold3Input(TypedDict):
110-
proteins: List[Int['_']]
111-
protein_atom_lens: List[Int['_']]
112-
nucleic_acids: List[Int['_']]
113-
nucleic_acid_atom_lens: List[Int['_']]
114-
metal_ions: List[int]
115-
misc_molecule_ids: List[int]
123+
@dataclass
124+
class Alphafold3Input:
125+
proteins: List[Int[' _']]
126+
protein_atom_lens: List[Int[' _']]
127+
nucleic_acids: List[Int[' _']]
128+
nucleic_acid_atom_lens: List[Int[' _']]
129+
metal_ions: Int[' _']
130+
misc_molecule_ids: Int[' _']
116131
ligands: List[Mol | str] # can be given as smiles
117-
atom_pos: List[Float['_ 3']] | Float['m 3'] | None
132+
atom_pos: List[Float['_ 3']] | Float['m 3'] | None = None
118133
templates: Float['t n n dt']
119134
msa: Float['s n dm']
120-
template_mask: Bool['t'] | None
121-
msa_mask: Bool['s'] | None
122-
distance_labels: Int['n n'] | None
123-
pae_labels: Int['n n'] | None
124-
pde_labels: Int['n'] | None
125-
resolved_labels: Int['n'] | None
135+
template_mask: Bool[' t'] | None = None
136+
msa_mask: Bool[' s'] | None = None
137+
distance_labels: Int['n n'] | None = None
138+
pae_labels: Int['n n'] | None = None
139+
pde_labels: Int[' n'] | None = None
140+
resolved_labels: Int[' n'] | None = None
126141

127142
@typecheck
128143
def af3_input_to_molecule_input(af3_input: Alphafold3Input) -> AtomInput:
129144
raise NotImplementedError
130145

146+
# pdb input
147+
148+
@typecheck
149+
@dataclass
150+
class PDBInput:
151+
filepath: str
152+
153+
@typecheck
154+
def pdb_input_to_alphafold3_input(pdb_input: PDBInput) -> Alphafold3Input:
155+
raise NotImplementedError
156+
131157
# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
132158
# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
133159

134160
INPUT_TO_ATOM_TRANSFORM = {
161+
AtomInput: identity,
135162
MoleculeInput: molecule_to_atom_input,
136-
Alphafold3Input: compose(af3_input_to_molecule_input, molecule_to_atom_input)
163+
Alphafold3Input: compose(
164+
af3_input_to_molecule_input,
165+
molecule_to_atom_input
166+
),
167+
PDBInput: compose(
168+
pdb_input_to_alphafold3_input,
169+
af3_input_to_molecule_input,
170+
molecule_to_atom_input
171+
)
137172
}
138173

139174
# function for extending the config
140175

141176
@typecheck
142177
def register_input_transform(
143178
input_type: Type,
144-
fn: Callable[[TypedDict], AtomInput]
179+
fn: Callable[[Any], AtomInput]
145180
):
181+
assert input_type not in INPUT_TO_ATOM_TRANSFORM, f'{input_type} is already registered'
146182
INPUT_TO_ATOM_TRANSFORM[input_type] = fn
183+
184+
# functions for transforming to atom inputs
185+
186+
def maybe_transform_to_atom_inputs(inputs: List[Any]) -> List[AtomInput]:
187+
atom_inputs = []
188+
189+
for i in inputs:
190+
191+
maybe_to_atom_fn = INPUT_TO_ATOM_TRANSFORM.get(type(i), None)
192+
193+
if not exists(maybe_to_atom_fn):
194+
raise TypeError(f'invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly')
195+
196+
atom_inputs.append(maybe_to_atom_fn(i))
197+
198+
return atom_inputs

alphafold3_pytorch/trainer.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import wraps, partial
4+
from dataclasses import asdict
45
from pathlib import Path
56

67
from alphafold3_pytorch.alphafold3 import Alphafold3
@@ -10,7 +11,6 @@
1011

1112
from alphafold3_pytorch.tensor_typing import (
1213
typecheck,
13-
beartype_isinstance,
1414
Int, Bool, Float
1515
)
1616

@@ -22,7 +22,7 @@
2222
from alphafold3_pytorch.inputs import (
2323
AtomInput,
2424
BatchedAtomInput,
25-
INPUT_TO_ATOM_TRANSFORM
25+
maybe_transform_to_atom_inputs
2626
)
2727

2828
import torch
@@ -88,43 +88,32 @@ def collate_af3_inputs(
8888
# go through all the inputs
8989
# and for any that is not AtomInput, try to transform it with the registered input type to corresponding registered function
9090

91-
atom_inputs = []
92-
93-
for i in inputs:
94-
if beartype_isinstance(i, AtomInput):
95-
atom_inputs.append(i)
96-
continue
97-
98-
maybe_to_atom_fn = INPUT_TO_ATOM_TRANSFORM.get(type(i), None)
99-
100-
if not exists(maybe_to_atom_fn):
101-
raise TypeError(f'invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly')
102-
103-
atom_inputs.append(maybe_to_atom_fn(i))
91+
atom_inputs = maybe_transform_to_atom_inputs(inputs)
10492

10593
# take care of windowing the atompair_inputs and atompair_ids if they are not windowed already
10694

10795
if exists(atoms_per_window):
10896
for atom_input in atom_inputs:
109-
atompair_inputs = atom_input['atompair_inputs']
110-
atompair_ids = atom_input.get('atompair_ids', None)
97+
atompair_inputs = atom_input.atompair_inputs
98+
atompair_ids = atom_input.atompair_ids
11199

112100
atompair_inputs_is_windowed = atompair_inputs.ndim == 4
113101

114102
if not atompair_inputs_is_windowed:
115-
atom_input['atompair_inputs'] = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
103+
atom_input.atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
116104

117105
if exists(atompair_ids):
118106
atompair_ids_is_windowed = atompair_ids.ndim == 3
119107

120108
if not atompair_ids_is_windowed:
121-
atom_input['atompair_ids'] = full_attn_bias_to_windowed(atompair_ids, window_size = atoms_per_window)
109+
atom_input.atompair_ids = full_attn_bias_to_windowed(atompair_ids, window_size = atoms_per_window)
122110

123111
# separate input dictionary into keys and values
124112

125-
keys = atom_inputs[0].keys()
126-
atom_inputs = [i.values() for i in atom_inputs]
113+
keys = asdict(atom_inputs[0]).keys()
114+
atom_inputs = [asdict(i).values() for i in atom_inputs]
127115

116+
print(keys)
128117
outputs = []
129118

130119
for grouped in zip(*atom_inputs):
@@ -183,7 +172,7 @@ def collate_af3_inputs(
183172

184173
# reconstitute dictionary
185174

186-
batched_atom_inputs = BatchedAtomInput(tuple(zip(keys, outputs)))
175+
batched_atom_inputs = BatchedAtomInput(**dict(tuple(zip(keys, outputs))))
187176
return batched_atom_inputs
188177

189178
@typecheck
@@ -522,7 +511,7 @@ def __call__(
522511
# model forwards
523512

524513
loss, loss_breakdown = self.model(
525-
**inputs,
514+
**asdict(inputs),
526515
return_loss_breakdown = True
527516
)
528517

@@ -582,11 +571,11 @@ def __call__(
582571

583572
for valid_batch in self.valid_dataloader:
584573
valid_loss, loss_breakdown = self.ema_model(
585-
**valid_batch,
574+
**asdict(valid_batch),
586575
return_loss_breakdown = True
587576
)
588577

589-
valid_batch_size = valid_batch.get('atom_inputs').shape[0]
578+
valid_batch_size = valid_batch.atom_inputs.shape[0]
590579
scale = valid_batch_size / self.valid_dataset_size
591580

592581
total_valid_loss += valid_loss.item() * scale
@@ -620,11 +609,11 @@ def __call__(
620609

621610
for test_batch in self.test_dataloader:
622611
test_loss, loss_breakdown = self.ema_model(
623-
**test_batch,
612+
**asdict(test_batch),
624613
return_loss_breakdown = True
625614
)
626615

627-
test_batch_size = test_batch.get('atom_inputs').shape[0]
616+
test_batch_size = test_batch.atom_inputs.shape[0]
628617
scale = test_batch_size / self.test_dataset_size
629618

630619
total_test_loss += test_loss.item() * scale

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.69"
3+
version = "0.1.70"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)