
Commit 2cbed8d

more preparation for the different types of data coming in from fine-tuning stages

1 parent f37f88d · commit 2cbed8d

File tree

5 files changed: +57 −17 lines changed


alphafold3_pytorch/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@
 from alphafold3_pytorch.trainer import (
     Trainer,
     DataLoader,
-    Alphafold3Input
+    AtomInput
 )

 __all__ = [
@@ -66,6 +66,6 @@
     ConfidenceHead,
     DistogramHead,
     Alphafold3,
-    Alphafold3Input,
+    AtomInput,
     Trainer
 ]

alphafold3_pytorch/trainer.py

Lines changed: 41 additions & 11 deletions

@@ -1,13 +1,16 @@
 from __future__ import annotations

+from functools import wraps, partial
 from pathlib import Path

 from alphafold3_pytorch.alphafold3 import Alphafold3
 from alphafold3_pytorch.attention import pad_at_dim

 from typing import TypedDict, List, Callable
+
 from alphafold3_pytorch.typing import (
     typecheck,
+    beartype_isinstance,
     Int, Bool, Float
 )

@@ -25,17 +28,17 @@
 # constants

 @typecheck
-class Alphafold3Input(TypedDict):
+class AtomInput(TypedDict):
     atom_inputs: Float['m dai']
-    molecule_atom_lens: Int[' n']
+    molecule_atom_lens: Int[' n']
     atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
-    additional_molecule_feats: Float['n 10']
+    additional_molecule_feats: Float['n 10']
     templates: Float['t n n dt']
     msa: Float['s n dm']
     template_mask: Bool[' t'] | None
     msa_mask: Bool[' s'] | None
     atom_pos: Float['m 3'] | None
-    molecule_atom_indices: Int[' n'] | None
+    molecule_atom_indices: Int[' n'] | None
     distance_labels: Int['n n'] | None
     pae_labels: Int['n n'] | None
     pde_labels: Int[' n'] | None
@@ -77,9 +80,18 @@ def accum_dict(

 @typecheck
 def collate_af3_inputs(
-    inputs: List[Alphafold3Input],
-    int_pad_value = -1
+    inputs: List,
+    int_pad_value = -1,
+    map_input_fn: Callable | None = None
 ):
+
+    if exists(map_input_fn):
+        inputs = [map_input_fn(i) for i in inputs]
+
+    # make sure all inputs are AtomInput
+
+    assert all([beartype_isinstance(i, AtomInput) for i in inputs])
+
     # separate input dictionary into keys and values

     keys = inputs[0].keys()
@@ -145,8 +157,18 @@ def collate_af3_inputs(

     return dict(tuple(zip(keys, outputs)))

-def DataLoader(*args, **kwargs):
-    return OrigDataLoader(*args, collate_fn = collate_af3_inputs, **kwargs)
+@typecheck
+def DataLoader(
+    *args,
+    map_input_fn: Callable | None = None,
+    **kwargs
+):
+    collate_fn = collate_af3_inputs
+
+    if exists(map_input_fn):
+        collate_fn = partial(collate_fn, map_input_fn = map_input_fn)
+
+    return OrigDataLoader(*args, collate_fn = collate_fn, **kwargs)

 # default scheduler used in paper w/ warmup

@@ -175,6 +197,7 @@ def __init__(
         num_train_steps: int,
         batch_size: int,
         grad_accum_every: int = 1,
+        map_dataset_input_fn: Callable | None = None,
         valid_dataset: Dataset | None = None,
         valid_every: int = 1000,
         test_dataset: Dataset | None = None,
@@ -229,9 +252,16 @@ def __init__(

         self.optimizer = optimizer

+        # if map dataset function given, curry into DataLoader
+
+        DataLoader_ = DataLoader
+
+        if exists(map_dataset_input_fn):
+            DataLoader_ = partial(DataLoader_, map_input_fn = map_dataset_input_fn)
+
         # train dataloader

-        self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
+        self.dataloader = DataLoader_(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

         # validation dataloader on the EMA model

@@ -241,15 +271,15 @@ def __init__(

         if self.needs_valid and self.is_main:
             self.valid_dataset_size = len(valid_dataset)
-            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)
+            self.valid_dataloader = DataLoader_(valid_dataset, batch_size = batch_size)

         # testing dataloader on EMA model

         self.needs_test = exists(test_dataset)

         if self.needs_test and self.is_main:
             self.test_dataset_size = len(test_dataset)
-            self.test_dataloader = DataLoader(test_dataset, batch_size = batch_size)
+            self.test_dataloader = DataLoader_(test_dataset, batch_size = batch_size)

         # training steps and num gradient accum steps
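Taken together, the trainer.py changes let a dataset yield whatever record type a given fine-tuning stage produces: DataLoader accepts an optional map_input_fn, curries it into collate_af3_inputs, and the collator converts every item to AtomInput (asserting as much via beartype_isinstance) before padding and batching. A minimal sketch of the wiring; to_atom_input, raw_dataset, and the record keys are hypothetical names, not from this commit:

    from alphafold3_pytorch import AtomInput, DataLoader

    # hypothetical converter: map one stage-specific record (a plain dict
    # here, by assumption) onto the fields of the AtomInput TypedDict
    def to_atom_input(record: dict) -> AtomInput:
        return AtomInput(
            atom_inputs = record['atom_feats'],
            atompair_inputs = record['pair_feats'],
            molecule_atom_lens = record['atom_lens'],
            **record['rest']  # remaining AtomInput fields, elided in this sketch
        )

    # collate_af3_inputs applies to_atom_input to every item, asserts the
    # results are valid AtomInputs, then pads and collates as before
    dataloader = DataLoader(
        raw_dataset,                  # hypothetical Dataset yielding raw dicts
        map_input_fn = to_atom_input,
        batch_size = 2
    )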

alphafold3_pytorch/typing.py

Lines changed: 11 additions & 1 deletion

@@ -4,6 +4,8 @@
 from torch import Tensor

 from beartype import beartype
+from beartype.door import is_bearable
+
 from jaxtyping import (
     Float,
     Int,
@@ -18,6 +20,11 @@

 # function

+def always(value):
+    def inner(*args, **kwargs):
+        return value
+    return inner
+
 def null_decorator(fn):
     @wraps(fn)
     def inner(*args, **kwargs):
@@ -43,9 +50,12 @@ def __getitem__(self, shapes: str):

 typecheck = jaxtyped(typechecker = beartype) if should_typecheck else null_decorator

+beartype_isinstance = is_bearable if should_typecheck else always(True)
+
 __all__ = [
     Float,
     Int,
     Bool,
-    typecheck
+    typecheck,
+    beartype_isinstance
 ]
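For context, beartype.door.is_bearable is beartype's runtime membership test for arbitrary type hints, including parameterized generics and TypedDicts such as AtomInput, which builtin isinstance cannot handle; when should_typecheck is off, always(True) turns the check into a no-op. A small illustration with stand-in hints, not taken from this repo:

    from beartype.door import is_bearable

    # runtime checks that builtin isinstance cannot express
    assert is_bearable([1, 2, 3], list[int])
    assert not is_bearable(['a', 'b'], list[int])

    # mirrors the fallback above: with typechecking disabled,
    # beartype_isinstance = always(True) accepts anything
    def always(value):
        def inner(*args, **kwargs):
            return value
        return inner

    assert always(True)(['a', 'b'], list[int])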

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.8"
+version = "0.1.9"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@

 from alphafold3_pytorch import (
     Alphafold3,
-    Alphafold3Input,
+    AtomInput,
     DataLoader,
     Trainer
 )
@@ -61,7 +61,7 @@ def __getitem__(self, idx):
         plddt_labels = torch.randint(0, 50, (seq_len,))
         resolved_labels = torch.randint(0, 2, (seq_len,))

-        return Alphafold3Input(
+        return AtomInput(
            atom_inputs = atom_inputs,
            atompair_inputs = atompair_inputs,
            molecule_atom_lens = molecule_atom_lens,
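The test dataset already returns AtomInput directly, so it needs no converter. For a fine-tuning stage whose dataset yields some other record type, the new map_dataset_input_fn argument would be passed to Trainer instead, which curries it into the train, valid, and test DataLoaders. A hedged sketch; apart from the parameter names visible in the diff, the positional model argument and to_atom_input are assumptions:

    # hypothetical wiring: the Trainer performs the AtomInput conversion itself
    trainer = Trainer(
        alphafold3,                            # an Alphafold3 instance (assumed positional)
        dataset = finetune_dataset,            # yields raw records, not AtomInput
        map_dataset_input_fn = to_atom_input,  # curried into every internal DataLoader
        num_train_steps = 100,
        batch_size = 2
    )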

0 commit comments