Skip to content

Commit b91ff6e

Browse files
committed
Test the atom-inputs collation function individually
1 parent e3aa3e2 commit b91ff6e

File tree

2 files changed: +34 lines added, -0 lines removed

2 files changed: +34 lines added, -0 lines removed

alphafold3_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from alphafold3_pytorch.trainer import (
4343
Trainer,
4444
DataLoader,
45+
collate_af3_inputs
4546
)
4647

4748
from alphafold3_pytorch.configs import (

tests/test_trainer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
DataLoader,
1515
Trainer,
1616
ConductorConfig,
17+
collate_af3_inputs,
1718
create_trainer_from_yaml,
1819
create_trainer_from_conductor_yaml,
1920
create_alphafold3_from_yaml
@@ -186,6 +187,38 @@ def test_trainer():
186187

187188
alphafold3 = Alphafold3.init_and_load('./some/nested/folder2/training.pt')
188189

190+
# test use of collation fn outside of trainer
191+
192+
def test_collate_fn():
193+
alphafold3 = Alphafold3(
194+
dim_atom_inputs = 77,
195+
dim_template_feats = 44,
196+
num_dist_bins = 38,
197+
confidence_head_kwargs = dict(
198+
pairformer_depth = 1
199+
),
200+
template_embedder_kwargs = dict(
201+
pairformer_stack_depth = 1
202+
),
203+
msa_module_kwargs = dict(
204+
depth = 1
205+
),
206+
pairformer_stack = dict(
207+
depth = 1
208+
),
209+
diffusion_module_kwargs = dict(
210+
atom_encoder_depth = 1,
211+
token_transformer_depth = 1,
212+
atom_decoder_depth = 1,
213+
),
214+
)
215+
216+
dataset = MockAtomDataset(1)
217+
218+
batched_atom_inputs = collate_af3_inputs([dataset[0]])
219+
220+
_, breakdown = alphafold3(**batched_atom_inputs, return_loss_breakdown = True)
221+
189222
# test creating trainer + alphafold3 from config
190223

191224
def test_trainer_config():

0 commit comments

Comments (0)