Skip to content

Commit 340dd87

Browse files
committed
some cleanup to trainer
1 parent c59b6ca commit 340dd87

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,26 @@ def __init__(
146146

147147
self.clip_grad_norm = clip_grad_norm
148148

149+
# steps
150+
151+
self.steps = 0
152+
149153
@property
150154
def is_main(self):
151155
return self.fabric.global_rank == 0
152156

157+
def print(self, *args, **kwargs):
158+
self.fabric.print(*args, **kwargs)
159+
160+
def log(self, **log_data):
161+
self.fabric.log_dict(log_data, step = self.steps)
162+
153163
def __call__(
154164
self
155165
):
156-
dl = iter(self.dataloader)
157-
158-
steps = 0
166+
dl = cycle(self.dataloader)
159167

160-
while steps < self.num_train_steps:
168+
while self.steps < self.num_train_steps:
161169

162170
for grad_accum_step in range(self.grad_accum_every):
163171
is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
@@ -169,7 +177,9 @@ def __call__(
169177

170178
self.fabric.backward(loss / self.grad_accum_every)
171179

172-
print(f'loss: {loss.item():.3f}')
180+
self.log(loss = loss)
181+
182+
self.print(f'loss: {loss.item():.3f}')
173183

174184
self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)
175185

@@ -181,6 +191,6 @@ def __call__(
181191
self.scheduler.step()
182192
self.optimizer.zero_grad()
183193

184-
steps += 1
194+
self.steps += 1
185195

186196
print(f'training complete')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.44"
3+
version = "0.0.45"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments (0)