Skip to content

Commit d7b374c

Browse files
committed
add lion
1 parent 7d1025a commit d7b374c

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from torch.utils.data import Sampler, Dataset, DataLoader as OrigDataLoader
4141
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
4242

43+
from lion_pytorch.foreach import Lion
4344
from adam_atan2_pytorch.foreach import AdamAtan2
4445

4546
from ema_pytorch import EMA
@@ -60,6 +61,9 @@ def default(v, d):
6061
def divisible_by(num, den):
6162
return (num % den) == 0
6263

64+
def at_most_one_of(*flags: bool) -> bool:
65+
assert sum([*map(int, flags)]) <= 1
66+
6367
def cycle(dataloader: DataLoader):
6468
while True:
6569
for batch in dataloader:
@@ -281,6 +285,7 @@ def __init__(
281285
use_foreach = True
282286
),
283287
use_adam_atan2: bool = False,
288+
use_lion: bool = False,
284289
use_torch_compile: bool = False
285290
):
286291
super().__init__()
@@ -325,13 +330,18 @@ def __init__(
325330
# optimizer
326331

327332
if not exists(optimizer):
328-
adam_klass = Adam
333+
optimizer_klass = Adam
334+
335+
assert at_most_one_of(use_adam_atan2, use_lion)
329336

330337
if use_adam_atan2:
331338
default_adam_kwargs.pop('eps', None)
332-
adam_klass = AdamAtan2
339+
optimizer_klass = AdamAtan2
340+
elif use_lion:
341+
default_adam_kwargs.pop('eps', None)
342+
optimizer_klass = Lion
333343

334-
optimizer = adam_klass(
344+
optimizer = optimizer_klass(
335345
model.parameters(),
336346
lr = lr,
337347
**default_adam_kwargs

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.71"
3+
version = "0.2.72"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -31,6 +31,7 @@ dependencies = [
3131
"einx>=0.2.2",
3232
"ema-pytorch>=0.5.0",
3333
"environs",
34+
"lion-pytorch",
3435
"joblib",
3536
"gemmi>=0.6.6",
3637
"frame-averaging-pytorch>=0.0.18",

0 commit comments

Comments
 (0)