
Commit 618f323 (1 parent: d52c3c0)

basic validation loop

2 files changed: 82 additions & 4 deletions


alphafold3_pytorch/trainer.py

Lines changed: 74 additions & 1 deletion
@@ -44,6 +44,9 @@ def exists(val):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def cycle(dataloader: DataLoader):
     while True:
         for batch in dataloader:
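The two helpers above are designed to work together in a step-based (rather than epoch-based) loop: cycle turns a finite DataLoader into an endless stream, and divisible_by gates periodic work such as the validation added in this commit. A minimal standalone sketch (the toy tensor dataset and the interval of 5 are illustrative, not from the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

def divisible_by(num, den):
    return (num % den) == 0

def cycle(dataloader):
    while True:
        for batch in dataloader:
            yield batch

dl = cycle(DataLoader(TensorDataset(torch.randn(10, 4)), batch_size = 2))

for step in range(1, 13):
    (batch,) = next(dl)           # never exhausts; restarts each epoch
    if divisible_by(step, 5):     # fires on steps 5 and 10
        print(f'periodic work at step {step}')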
@@ -74,6 +77,8 @@ def __init__(
         num_train_steps: int,
         batch_size: int,
         grad_accum_every: int = 1,
+        valid_dataset: Dataset | None = None,
+        valid_every: int = 1000,
         optimizer: Optimizer | None = None,
         scheduler: LRScheduler | None = None,
         ema_decay = 0.999,
@@ -122,10 +127,22 @@ def __init__(
 
         self.optimizer = optimizer
 
-        # data
+        # train dataloader
 
         self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
 
+        # validation dataloader on the EMA model
+
+        self.valid_every = valid_every
+
+        self.needs_valid = exists(valid_dataset)
+
+        if self.needs_valid and self.is_main:
+            self.valid_dataset_size = len(valid_dataset)
+            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)
+
+        # training steps and num gradient accum steps
+
         self.num_train_steps = num_train_steps
         self.grad_accum_every = grad_accum_every
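Note the asymmetry between the two loaders: the train loader shuffles and drops the ragged final batch, while the validation loader keeps the DataLoader defaults (shuffle = False, drop_last = False), so every validation sample is visited exactly once per pass, which the size-weighted loss averaging in the validation loop further down relies on. A small sketch of the difference with a toy 5-sample dataset (illustrative, not from the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.randn(5, 3))

train_dl = DataLoader(data, batch_size = 2, shuffle = True, drop_last = True)
valid_dl = DataLoader(data, batch_size = 2)

print(sum(t.shape[0] for (t,) in train_dl))   # 4 - ragged last batch dropped
print(sum(t.shape[0] for (t,) in valid_dl))   # 5 - every sample seen once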

@@ -154,6 +171,9 @@ def __init__(
     def is_main(self):
         return self.fabric.global_rank == 0
 
+    def wait(self):
+        self.fabric.barrier()
+
     def print(self, *args, **kwargs):
         self.fabric.print(*args, **kwargs)
 
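wait is a thin wrapper over Fabric's collective barrier: every process blocks until all ranks have reached the same point. The training loop below uses it to fence the main-rank-only EMA update and validation so the other ranks cannot race ahead. A minimal sketch of the pattern, assuming Lightning Fabric launched over two CPU processes (illustrative, not part of the commit):

from lightning.fabric import Fabric

fabric = Fabric(accelerator = 'cpu', devices = 2)
fabric.launch()

fabric.barrier()                 # all ranks sync up here

if fabric.global_rank == 0:
    print('rank 0 doing exclusive work')   # e.g. EMA update, validation

fabric.barrier()                 # others resume only once rank 0 is done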

@@ -165,35 +185,88 @@ def __call__(
     ):
         dl = cycle(self.dataloader)
 
+        # while less than required number of training steps
+
         while self.steps < self.num_train_steps:
 
+            self.model.train()
+
+            # gradient accumulation
+
             for grad_accum_step in range(self.grad_accum_every):
                 is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
 
                 inputs = next(dl)
 
                 with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):
+
+                    # model forwards
+
                     loss, loss_breakdown = self.model(
                         **inputs,
                         return_loss_breakdown = True
                     )
 
+                    # backwards
+
                     self.fabric.backward(loss / self.grad_accum_every)
 
+                # log entire loss breakdown
+
                 self.log(**loss_breakdown._asdict())
 
                 self.print(f'loss: {loss.item():.3f}')
 
+            # clip gradients
+
             self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)
 
+            # optimizer step
+
             self.optimizer.step()
 
+            # update exponential moving average
+
+            self.wait()
+
             if self.is_main:
                 self.ema_model.update()
 
+            self.wait()
+
+            # scheduler
+
             self.scheduler.step()
             self.optimizer.zero_grad()
 
             self.steps += 1
 
+            # maybe validate, for now, only on main with EMA model
+
+            if (
+                self.is_main and
+                self.needs_valid and
+                divisible_by(self.steps, self.valid_every)
+            ):
+                with torch.no_grad():
+                    self.ema_model.eval()
+
+                    total_valid_loss = 0.
+
+                    for valid_batch in self.valid_dataloader:
+                        valid_loss, valid_loss_breakdown = self.ema_model(
+                            **valid_batch,
+                            return_loss_breakdown = True
+                        )
+
+                        valid_batch_size = valid_batch.get('atom_inputs').shape[0]
+                        scale = valid_batch_size / self.valid_dataset_size
+
+                        scaled_valid_loss = valid_loss.item() * scale
+                        total_valid_loss += scaled_valid_loss
+
+                    self.print(f'valid loss: {total_valid_loss:.3f}')
+
+            self.wait()
+
         print(f'training complete')
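The scaling inside the validation loop is a batch-size-weighted average: summing loss_i * batch_size_i / dataset_size over all batches recovers the per-sample mean loss even when the final batch is smaller, assuming each per-batch loss is itself a mean over its samples. A quick numeric check with made-up losses:

losses       = [0.8, 0.6, 1.0]    # per-batch mean losses (made up)
batch_sizes  = [2, 2, 1]          # 5-sample dataset, ragged last batch
dataset_size = sum(batch_sizes)

total = sum(l * b / dataset_size for l, b in zip(losses, batch_sizes))
print(total)    # 0.76 - identical to averaging over all 5 samples at once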

tests/test_trainer.py

Lines changed: 8 additions & 3 deletions
@@ -13,17 +13,19 @@
 
 # mock dataset
 
-class AtomDataset(Dataset):
+class MockAtomDataset(Dataset):
     def __init__(
         self,
+        data_length,
         seq_len = 16,
         atoms_per_window = 27
     ):
+        self.data_length = data_length
         self.seq_len = seq_len
         self.atom_seq_len = seq_len * atoms_per_window
 
     def __len__(self):
-        return 100
+        return self.data_length
 
     def __getitem__(self, idx):
         seq_len = self.seq_len
@@ -93,14 +95,17 @@ def test_trainer():
         ),
     )
 
-    dataset = AtomDataset()
+    dataset = MockAtomDataset(100)
+    valid_dataset = MockAtomDataset(2)
 
     trainer = Trainer(
         alphafold3,
         dataset = dataset,
+        valid_dataset = valid_dataset,
         accelerator = 'cpu',
         num_train_steps = 2,
         batch_size = 1,
+        valid_every = 1,
         grad_accum_every = 2
     )
 
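With valid_every = 1 and num_train_steps = 2, the test exercises the new validation branch after each of the two training steps, and the length-2 MockAtomDataset with batch_size = 1 walks the validation loader through two batches, covering the size-weighted accumulation.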
