Skip to content

Commit b139001

Browse files
committed
update ema-pytorch
1 parent ca644de commit b139001

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def __init__(
222222
checkpoint_folder: str = './checkpoints',
223223
overwrite_checkpoints: bool = False,
224224
fabric_kwargs: dict = dict(),
225-
ema_kwargs: dict = dict()
225+
ema_kwargs: dict = dict(
226+
use_foreach = True
227+
)
226228
):
227229
super().__init__()
228230

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.47"
3+
version = "0.1.48"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -27,7 +27,7 @@ dependencies = [
2727
"biopython>=1.83",
2828
"einops>=0.8.0",
2929
"einx>=0.2.2",
30-
"ema-pytorch>=0.4.8",
30+
"ema-pytorch>=0.5.0",
3131
"environs",
3232
"frame-averaging-pytorch>=0.0.18",
3333
"jaxtyping>=0.2.28",

tests/test_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,12 @@ def test_trainer():
149149
valid_every = 1,
150150
grad_accum_every = 2,
151151
checkpoint_every = 1,
152-
overwrite_checkpoints = True
152+
overwrite_checkpoints = True,
153+
ema_kwargs = dict(
154+
use_foreach = True,
155+
update_after_step = 0,
156+
update_every = 1
157+
)
153158
)
154159

155160
trainer()

0 commit comments

Comments
 (0)