Skip to content

Commit 5d0aae6

Browse files
committed
allow for using AdamAtan2 and get rid of the epsilon
1 parent ee68e6d commit 5d0aae6

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from torch.utils.data import Sampler, Dataset, DataLoader as OrigDataLoader
4040
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
4141

42+
from adam_atan2_pytorch.foreach import AdamAtan2
43+
4244
from ema_pytorch import EMA
4345

4446
from lightning import Fabric
@@ -275,7 +277,8 @@ def __init__(
275277
use_ema: bool = True,
276278
ema_kwargs: dict = dict(
277279
use_foreach = True
278-
)
280+
),
281+
use_adam_atan2: bool = False
279282
):
280283
super().__init__()
281284

@@ -305,7 +308,13 @@ def __init__(
305308
# optimizer
306309

307310
if not exists(optimizer):
308-
optimizer = Adam(
311+
adam_klass = Adam
312+
313+
if use_adam_atan2:
314+
del default_adam_kwargs['eps']
315+
adam_klass = AdamAtan2
316+
317+
optimizer = adam_klass(
309318
model.parameters(),
310319
lr = lr,
311320
**default_adam_kwargs

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.53"
3+
version = "0.2.54"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -23,6 +23,7 @@ classifiers=[
2323
]
2424

2525
dependencies = [
26+
"adam-atan2-pytorch>=0.0.8",
2627
"beartype",
2728
"biopython>=1.83",
2829
"CoLT5-attention>=0.11.0",

0 commit comments

Comments
 (0)