11from __future__ import annotations
22
3+ from functools import wraps
34from pathlib import Path
45
56from alphafold3_pytorch .alphafold3 import Alphafold3
67from alphafold3_pytorch .attention import pad_at_dim
78
89from typing import TypedDict , List
10+
911from alphafold3_pytorch .typing import (
1012 typecheck ,
13+ beartype_isinstance ,
1114 Int , Bool , Float
1215)
1316
2528# constants
2629
@typecheck
class AtomInput(TypedDict):
    """
    Dictionary schema for a single (un-batched) training sample.

    `collate_af3_inputs` asserts that every sample it receives satisfies this
    schema (via `beartype_isinstance`) before collating the batch.

    Each annotation pairs a dtype wrapper (`Float` / `Int` / `Bool`, imported
    from `alphafold3_pytorch.typing`) with a shape string. Fields annotated
    with `| None` may hold `None`.

    NOTE(review): the shape strings look like named-axis specs (same letter =
    same size across fields) - confirm against `alphafold3_pytorch.typing`.
    """

    # fields that must always hold a value
    atom_inputs: Float['m dai']
    molecule_atom_lens: Int[' n']
    atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']  # either of two accepted layouts
    additional_molecule_feats: Float['n 10']
    templates: Float['t n n dt']
    msa: Float['s n dm']

    # fields whose value may be None
    template_mask: Bool[' t'] | None
    msa_mask: Bool[' s'] | None
    atom_pos: Float['m 3'] | None
    molecule_atom_indices: Int[' n'] | None
    distance_labels: Int['n n'] | None
    pae_labels: Int['n n'] | None
    pde_labels: Int[' n'] | None
@@ -77,9 +80,18 @@ def accum_dict(
7780
7881@typecheck
7982def collate_af3_inputs (
80- inputs : List [Alphafold3Input ],
81- int_pad_value = - 1
83+ inputs : List ,
84+ int_pad_value = - 1 ,
85+ map_input_fn : Callable | None = None
8286):
87+
88+ if exists (map_input_fn ):
89+ inputs = [map_input_fn (i ) for i in inputs ]
90+
91+ # make sure all inputs are AtomInput
92+
93+ assert all ([beartype_isinstance (i , AtomInput ) for i in inputs ])
94+
8395 # separate input dictionary into keys and values
8496
8597 keys = inputs [0 ].keys ()
@@ -145,8 +157,18 @@ def collate_af3_inputs(
145157
146158 return dict (tuple (zip (keys , outputs )))
147159
@typecheck
def DataLoader(
    *args,
    map_input_fn: Callable | None = None,
    **kwargs
):
    """
    Build an `OrigDataLoader` whose batches are collated by `collate_af3_inputs`.

    Args:
        *args: positional arguments forwarded unchanged to `OrigDataLoader`.
        map_input_fn: optional callable. When given, it is bound as the
            `map_input_fn` keyword of `collate_af3_inputs`, which applies it
            to every sample before collating.
        **kwargs: keyword arguments forwarded unchanged to `OrigDataLoader`.
            `collate_fn` is always supplied by this wrapper, so also passing
            it in `kwargs` raises `TypeError` (duplicate keyword argument).

    Returns:
        The object returned by `OrigDataLoader(...)`.
    """
    collate_fn = collate_af3_inputs

    if exists(map_input_fn):
        # imported at function scope so this wrapper does not depend on
        # `partial` being present in the module-level functools import
        from functools import partial

        collate_fn = partial(collate_fn, map_input_fn = map_input_fn)

    return OrigDataLoader(*args, collate_fn = collate_fn, **kwargs)
150172
151173# default scheduler used in paper w/ warmup
152174
@@ -175,6 +197,7 @@ def __init__(
175197 num_train_steps : int ,
176198 batch_size : int ,
177199 grad_accum_every : int = 1 ,
200+ map_dataset_input_fn : Callable | None = None ,
178201 valid_dataset : Dataset | None = None ,
179202 valid_every : int = 1000 ,
180203 test_dataset : Dataset | None = None ,
@@ -229,9 +252,16 @@ def __init__(
229252
230253 self .optimizer = optimizer
231254
255+ # if map dataset function given, curry into DataLoader
256+
257+ DataLoader_ = DataLoader
258+
259+ if exists (map_dataset_input_fn ):
260+ DataLoader_ = partial (DataLoader_ , map_input_fn = map_dataset_input_fn )
261+
232262 # train dataloader
233263
234- self .dataloader = DataLoader (dataset , batch_size = batch_size , shuffle = True , drop_last = True )
264+ self .dataloader = DataLoader_ (dataset , batch_size = batch_size , shuffle = True , drop_last = True )
235265
236266 # validation dataloader on the EMA model
237267
@@ -241,15 +271,15 @@ def __init__(
241271
242272 if self .needs_valid and self .is_main :
243273 self .valid_dataset_size = len (valid_dataset )
244- self .valid_dataloader = DataLoader (valid_dataset , batch_size = batch_size )
274+ self .valid_dataloader = DataLoader_ (valid_dataset , batch_size = batch_size )
245275
246276 # testing dataloader on EMA model
247277
248278 self .needs_test = exists (test_dataset )
249279
250280 if self .needs_test and self .is_main :
251281 self .test_dataset_size = len (test_dataset )
252- self .test_dataloader = DataLoader (test_dataset , batch_size = batch_size )
282+ self .test_dataloader = DataLoader_ (test_dataset , batch_size = batch_size )
253283
254284 # training steps and num gradient accum steps
255285
0 commit comments