Skip to content

Commit d271481

Browse files
authored
add distributed evaluation for fabric training, make sure it can be t… (#291)
1 parent e7aeed2 commit d271481

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
checkpoint_folder: str = './checkpoints',
172172
overwrite_checkpoints: bool = False,
173173
fabric_kwargs: dict = dict(),
174+
distributed_eval: bool = True,
174175
fp16: bool = False,
175176
use_ema: bool = True,
176177
ema_kwargs: dict = dict(
@@ -201,10 +202,16 @@ def __init__(
201202
self.fabric = fabric
202203
fabric.launch()
203204

205+
# whether to evaluate on every rank (distributed) or only on the root node
206+
# evaluating only on the root node saves the other machines from keeping track of an EMA model
207+
208+
self.distributed_eval = distributed_eval
209+
self.will_eval_or_test = self.is_main or distributed_eval
210+
204211
# exponential moving average
205212

206213
self.ema_model = None
207-
self.has_ema = self.is_main and use_ema
214+
self.has_ema = self.will_eval_or_test and use_ema
208215

209216
if self.has_ema:
210217
self.ema_model = EMA(
@@ -282,16 +289,18 @@ def __init__(
282289
self.valid_every = valid_every
283290

284291
self.needs_valid = exists(valid_dataset)
292+
self.valid_dataloader = None
285293

286-
if self.needs_valid and self.is_main:
294+
if self.needs_valid and self.will_eval_or_test:
287295
self.valid_dataset_size = len(valid_dataset)
288296
self.valid_dataloader = DataLoader_(valid_dataset, batch_size = batch_size)
289297

290298
# testing dataloader on EMA model
291299

292300
self.needs_test = exists(test_dataset)
301+
self.test_dataloader = None
293302

294-
if self.needs_test and self.is_main:
303+
if self.needs_test and self.will_eval_or_test:
295304
self.test_dataset_size = len(test_dataset)
296305
self.test_dataloader = DataLoader_(test_dataset, batch_size = batch_size)
297306

@@ -306,6 +315,12 @@ def __init__(
306315

307316
fabric.setup_dataloaders(self.dataloader)
308317

318+
if exists(self.valid_dataloader) and self.distributed_eval:
319+
fabric.setup_dataloaders(self.valid_dataloader)
320+
321+
if exists(self.test_dataloader) and self.distributed_eval:
322+
fabric.setup_dataloaders(self.test_dataloader)
323+
309324
# scheduler
310325

311326
if not exists(scheduler):
@@ -555,7 +570,7 @@ def __call__(
555570
# maybe validate with EMA model - on main only, or on every rank when distributed_eval is set
556571

557572
if (
558-
self.is_main and
573+
self.will_eval_or_test and
559574
self.needs_valid and
560575
divisible_by(self.steps, self.valid_every)
561576
):
@@ -585,6 +600,11 @@ def __call__(
585600

586601
valid_loss_breakdown = {f'valid_{k}':v for k, v in valid_loss_breakdown.items()}
587602

603+
# reduce valid loss breakdown
604+
605+
if self.distributed_eval:
606+
valid_loss_breakdown = self.fabric.all_reduce(valid_loss_breakdown, reduce_op = 'sum')
607+
588608
# log
589609

590610
self.log(**valid_loss_breakdown)
@@ -598,7 +618,7 @@ def __call__(
598618

599619
# maybe test
600620

601-
if self.is_main and self.needs_test:
621+
if self.will_eval_or_test and self.needs_test:
602622
eval_model = default(self.ema_model, self.model)
603623

604624
with torch.no_grad(), to_device_and_back(eval_model, self.device):
@@ -625,6 +645,11 @@ def __call__(
625645

626646
test_loss_breakdown = {f'test_{k}':v for k, v in test_loss_breakdown.items()}
627647

648+
# reduce test loss breakdown
649+
650+
if self.distributed_eval:
651+
test_loss_breakdown = self.fabric.all_reduce(test_loss_breakdown, reduce_op = 'sum')
652+
628653
# log
629654

630655
self.log(**test_loss_breakdown)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.41"
3+
version = "0.5.42"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)