Skip to content

Commit 06ae6af

Browse files
committed
sketch out the general outlines for how all the disparate types of input data will be handled
1 parent 9ea3df4 commit 06ae6af

File tree

3 files changed

+54
-9
lines changed

3 files changed

+54
-9
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Int, Bool, Float
66
)
77

8-
# constants
8+
# atom level, what Alphafold3 accepts
99

1010
@typecheck
1111
class AtomInput(TypedDict):
@@ -25,3 +25,33 @@ class AtomInput(TypedDict):
2525
pae_labels: Int['*b n n'] | None
2626
pde_labels: Int['*b n'] | None
2727
resolved_labels: Int['*b n'] | None
28+
29+
# residue level - single chain proteins for starters
30+
31+
@typecheck
32+
class ProteinInput(TypedDict):
33+
residue_ids: Int['*b n']
34+
residue_atom_lens: Int['*b n']
35+
templates: Float['*b t n n dt']
36+
msa: Float['*b s n dm']
37+
template_mask: Bool['*b t'] | None
38+
msa_mask: Bool['*b s'] | None
39+
atom_pos: Float['*b m 3'] | None
40+
distance_labels: Int['*b n n'] | None
41+
pae_labels: Int['*b n n'] | None
42+
pde_labels: Int['*b n'] | None
43+
resolved_labels: Int['*b n'] | None
44+
45+
@typecheck
46+
def single_protein_input_to_atom_input(
47+
residue_input: ProteinInput
48+
) -> AtomInput:
49+
50+
raise NotImplementedError
51+
52+
# the config used for keeping track of all the disparate inputs and their transforms down to AtomInput
53+
# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
54+
55+
INPUT_TO_ATOM_TRANSFORM = {
56+
ProteinInput: single_protein_input_to_atom_input
57+
}

alphafold3_pytorch/trainer.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
)
1616

1717
from alphafold3_pytorch.inputs import (
18-
AtomInput
18+
AtomInput,
19+
INPUT_TO_ATOM_TRANSFORM
1920
)
2021

2122
import torch
@@ -77,18 +78,31 @@ def collate_af3_inputs(
7778
if exists(map_input_fn):
7879
inputs = [map_input_fn(i) for i in inputs]
7980

80-
# make sure all inputs are AtomInput
81+
# go through all the inputs
82+
# and for any that is not AtomInput, try to transform it with the registered input type to corresponding registered function
8183

82-
assert all([beartype_isinstance(i, AtomInput) for i in inputs])
84+
atom_inputs = []
85+
86+
for i in inputs:
87+
if beartype_isinstance(i, AtomInput):
88+
atom_inputs.append(i)
89+
continue
90+
91+
maybe_to_atom_fn = INPUT_TO_ATOM_TRANSFORM.get(type(i), None)
92+
93+
if not exists(maybe_to_atom_fn):
94+
raise TypeError(f'invalid input type {type(i)} being passed into Trainer that is not converted to AtomInput correctly')
95+
96+
atom_inputs = maybe_to_atom_fn(i)
8397

8498
# separate input dictionary into keys and values
8599

86-
keys = inputs[0].keys()
87-
inputs = [i.values() for i in inputs]
100+
keys = atom_inputs[0].keys()
101+
atom_inputs = [i.values() for i in atom_inputs]
88102

89103
outputs = []
90104

91-
for grouped in zip(*inputs):
105+
for grouped in zip(*atom_inputs):
92106
# if all None, just return None
93107

94108
not_none_grouped = [*filter(exists, grouped)]
@@ -144,7 +158,8 @@ def collate_af3_inputs(
144158

145159
# reconstitute dictionary
146160

147-
return AtomInput(tuple(zip(keys, outputs)))
161+
batched_atom_inputs = AtomInput(tuple(zip(keys, outputs)))
162+
return batched_atom_inputs
148163

149164
@typecheck
150165
def DataLoader(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.44"
3+
version = "0.1.45"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)