Skip to content

Commit 39d282f

Browse files
committed
A new paper (Switch EMA) claims there is a free lunch in setting the model weights to the EMA weights every epoch. Allow researchers to experiment with this; it is conveniently already available in ema-pytorch because of the hare and tortoise paper.
1 parent: 4768a65; commit: 39d282f

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,14 @@ docker run -v .:/data --gpus all -it af3
483483
journal = {bioRxiv}
484484
}
485485
```
486+
487+
```bibtex
488+
@article{Li2024SwitchEA,
489+
title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
490+
author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
491+
journal = {ArXiv},
492+
year = {2024},
493+
volume = {abs/2402.09240},
494+
url = {https://api.semanticscholar.org/CorpusID:267657558}
495+
}
496+
```

alphafold3_pytorch/trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __init__(
178178
use_foreach = True
179179
),
180180
ema_on_cpu = False,
181+
ema_update_model_with_ema_every: int | None = None,
181182
use_adam_atan2: bool = False,
182183
use_lion: bool = False,
183184
use_torch_compile: bool = False
@@ -220,6 +221,7 @@ def __init__(
220221
include_online_model = False,
221222
allow_different_devices = True,
222223
coerce_dtype = True,
224+
update_model_with_ema_every = ema_update_model_with_ema_every,
223225
**ema_kwargs
224226
)
225227

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.6.2"
3+
version = "0.6.3"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },
@@ -33,7 +33,7 @@ dependencies = [
3333
"CoLT5-attention>=0.11.0",
3434
"einops>=0.8.0",
3535
"einx>=0.2.2",
36-
"ema-pytorch>=0.6.4",
36+
"ema-pytorch>=0.7.0",
3737
"environs",
3838
"lion-pytorch>=0.2.2",
3939
"joblib",

0 commit comments

Comments
 (0)