Skip to content

Commit 4532883

Browse files
committed
setup fabric
1 parent f73bef5 commit 4532883

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ema_pytorch import EMA
1111

12+
from lightning import Fabric
1213

1314
# helpers
1415

@@ -49,9 +50,21 @@ def __init__(
4950
),
5051
clip_grad_norm = 10.,
5152
default_lambda_lr = default_lambda_lr_fn,
53+
fabric: Fabric | None = None,
54+
accelerator = 'auto',
55+
fabric_kwargs: dict = dict(),
5256
ema_kwargs: dict = dict()
5357
):
5458
super().__init__()
59+
60+
if not exists(fabric):
61+
fabric = Fabric(accelerator = accelerator, **fabric_kwargs)
62+
63+
self.fabric = fabric
64+
fabric.launch()
65+
66+
# model
67+
5568
self.model = model
5669

5770
# exponential moving average
@@ -74,6 +87,10 @@ def __init__(
7487

7588
self.optimizer = optimizer
7689

90+
# setup fabric
91+
92+
self.model, self.optimizer = fabric.setup(self.model, self.optimizer)
93+
7794
# scheduler
7895

7996
if not exists(scheduler):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"ema-pytorch>=0.4.8",
3030
"environs",
3131
"jaxtyping>=0.2.28",
32-
"pytorch-lightning>=2.2.5",
32+
"lightning>=2.2.5",
3333
"taylor-series-linear-attention>=0.1.9",
3434
"torch>=2.1",
3535
"tqdm",

0 commit comments

Comments
 (0)