File tree Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Original file line number Diff line number Diff line change 3939from torch .utils .data import Sampler , Dataset , DataLoader as OrigDataLoader
4040from torch .optim .lr_scheduler import LambdaLR , LRScheduler
4141
42+ from adam_atan2_pytorch .foreach import AdamAtan2
43+
4244from ema_pytorch import EMA
4345
4446from lightning import Fabric
@@ -275,7 +277,8 @@ def __init__(
275277 use_ema : bool = True ,
276278 ema_kwargs : dict = dict (
277279 use_foreach = True
278- )
280+ ),
281+ use_adam_atan2 : bool = False
279282 ):
280283 super ().__init__ ()
281284
@@ -305,7 +308,13 @@ def __init__(
305308 # optimizer
306309
307310 if not exists (optimizer ):
308- optimizer = Adam (
311+ adam_klass = Adam
312+
313+ if use_adam_atan2 :
314+ del default_adam_kwargs ['eps' ]
315+ adam_klass = AdamAtan2
316+
317+ optimizer = adam_klass (
309318 model .parameters (),
310319 lr = lr ,
311320 ** default_adam_kwargs
Original file line number Diff line number Diff line change 11[project ]
22name = " alphafold3-pytorch"
3- version = " 0.2.53 "
3+ version = " 0.2.54 "
44description = " Alphafold 3 - Pytorch"
55authors = [
66 {
name =
" Phil Wang" ,
email =
" [email protected] " }
@@ -23,6 +23,7 @@ classifiers=[
2323]
2424
2525dependencies = [
26+ " adam-atan2-pytorch>=0.0.8" ,
2627 " beartype" ,
2728 " biopython>=1.83" ,
2829 " CoLT5-attention>=0.11.0" ,
You can’t perform that action at this time.
0 commit comments