Commit 5fe0cbd

add test dataset
1 parent 585aca0

3 files changed: +44 -2 lines


alphafold3_pytorch/trainer.py

Lines changed: 40 additions & 0 deletions
@@ -97,6 +97,7 @@ def __init__(
         grad_accum_every: int = 1,
         valid_dataset: Dataset | None = None,
         valid_every: int = 1000,
+        test_dataset: Dataset | None = None,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         ema_decay = 0.999,
@@ -159,6 +160,14 @@ def __init__(
             self.valid_dataset_size = len(valid_dataset)
             self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)
 
+        # testing dataloader on EMA model
+
+        self.needs_test = exists(test_dataset)
+
+        if self.needs_test and self.is_main:
+            self.test_dataset_size = len(test_dataset)
+            self.test_dataloader = DataLoader(test_dataset, batch_size = batch_size)
+
         # training steps and num gradient accum steps
 
         self.num_train_steps = num_train_steps
@@ -347,4 +356,35 @@ def __call__(
 
         self.wait()
 
+        # maybe test
+
+        if self.is_main and self.needs_test:
+            with torch.no_grad():
+                self.ema_model.eval()
+
+                total_test_loss = 0.
+                test_loss_breakdown = None
+
+                for test_batch in self.test_dataloader:
+                    test_loss, loss_breakdown = self.ema_model(
+                        **test_batch,
+                        return_loss_breakdown = True
+                    )
+
+                    test_batch_size = test_batch.get('atom_inputs').shape[0]
+                    scale = test_batch_size / self.test_dataset_size
+
+                    total_test_loss += test_loss.item() * scale
+                    test_loss_breakdown = accum_dict(test_loss_breakdown, loss_breakdown._asdict(), scale = scale)
+
+                self.print(f'test loss: {total_test_loss:.3f}')
+
+                # prepend test_ to all losses for logging
+
+                test_loss_breakdown = {f'test_{k}':v for k, v in test_loss_breakdown.items()}
+
+                # log
+
+                self.log(**test_loss_breakdown)
+
         print(f'training complete')
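
A note on the loss math in the new test pass: each batch's loss is weighted by batch_size / test_dataset_size before being summed, so the accumulated total is the mean over the whole test set even when the last batch is smaller (assuming the dataloader does not drop it). Below is a minimal, self-contained sketch of that weighted accumulation outside the Trainer; the accum_dict here is a stand-in for the repo's helper of the same name, and the breakdown keys and loss values are made up.

    # sketch: dataset-size-weighted averaging of per-batch losses
    # accum_dict is a stand-in for the repo helper of the same name

    def accum_dict(acc, new, scale = 1.):
        # accumulate `new` into `acc`, scaling each value
        acc = acc or {}
        return {k: acc.get(k, 0.) + v * scale for k, v in new.items()}

    dataset_size = 10
    batches = [                       # (batch_size, loss, loss breakdown) - made-up numbers
        (4, 2.0, {'diffusion': 1.5, 'confidence': 0.5}),
        (4, 1.0, {'diffusion': 0.7, 'confidence': 0.3}),
        (2, 3.0, {'diffusion': 2.0, 'confidence': 1.0}),
    ]

    total_loss = 0.
    loss_breakdown = None

    for batch_size, loss, breakdown in batches:
        scale = batch_size / dataset_size        # weight by fraction of the dataset
        total_loss += loss * scale
        loss_breakdown = accum_dict(loss_breakdown, breakdown, scale = scale)

    # prefix keys for logging, mirroring the trainer's test_ prefix
    loss_breakdown = {f'test_{k}': v for k, v in loss_breakdown.items()}

    print(f'test loss: {total_loss:.3f}')        # 1.800
    print(loss_breakdown)

Scaling inside the loop avoids keeping per-batch losses around and gives the same result as taking the sample-weighted average of the per-batch mean losses at the end.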

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.58"
+version = "0.0.59"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,8 @@ def test_trainer():
     )
 
     dataset = MockAtomDataset(100)
-    valid_dataset = MockAtomDataset(2)
+    valid_dataset = MockAtomDataset(4)
+    test_dataset = MockAtomDataset(2)
 
     # test saving and loading from Alphafold3, independent of lightning
 
@@ -126,6 +127,7 @@ def test_trainer():
         alphafold3,
         dataset = dataset,
         valid_dataset = valid_dataset,
+        test_dataset = test_dataset,
         accelerator = 'cpu',
         num_train_steps = 2,
         batch_size = 1,
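
For context on what a dataset like MockAtomDataset has to provide for the new test path: the trainer splats each batch into the model with **test_batch and reads the batch size from test_batch.get('atom_inputs').shape[0], so the dataset should yield dicts of tensors that PyTorch's default collation turns into a dict of batched tensors. A self-contained sketch of that pattern follows; the MockDataset class, field names and tensor shapes here are illustrative stand-ins, not the repo's actual mock.

    import torch
    from torch.utils.data import Dataset, DataLoader

    class MockDataset(Dataset):
        # illustrative stand-in - field names and shapes are made up
        def __init__(self, length):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            return dict(
                atom_inputs = torch.randn(16, 3),
                atom_mask = torch.ones(16).bool(),
            )

    dataloader = DataLoader(MockDataset(2), batch_size = 1)

    for batch in dataloader:
        # default collation stacks each field, so batch is a dict of batched tensors
        batch_size = batch.get('atom_inputs').shape[0]
        print(batch_size, {k: tuple(v.shape) for k, v in batch.items()})
        # a model taking keyword tensors could then be called as model(**batch)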
