Skip to content

Commit 9e88ac5

Browse files
committed
torch compile flag in trainer
1 parent 1f89c62 commit 9e88ac5

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

alphafold3_pytorch/tensor_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ def __getitem__(self, shapes: str):
6969
Int,
7070
Bool,
7171
typecheck,
72+
should_typecheck,
7273
beartype_isinstance
7374
]

alphafold3_pytorch/trainer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import TypedDict, List, Callable
1111

1212
from alphafold3_pytorch.tensor_typing import (
13+
should_typecheck,
1314
typecheck,
1415
Int, Bool, Float
1516
)
@@ -278,7 +279,8 @@ def __init__(
278279
ema_kwargs: dict = dict(
279280
use_foreach = True
280281
),
281-
use_adam_atan2: bool = False
282+
use_adam_atan2: bool = False,
283+
use_torch_compile: bool = False
282284
):
283285
super().__init__()
284286

@@ -288,10 +290,6 @@ def __init__(
288290
self.fabric = fabric
289291
fabric.launch()
290292

291-
# model
292-
293-
self.model = model
294-
295293
# exponential moving average
296294

297295
self.ema_model = None
@@ -305,6 +303,16 @@ def __init__(
305303
**ema_kwargs
306304
)
307305

306+
# maybe torch compile
307+
308+
if use_torch_compile:
309+
assert not should_typecheck, f'does not work well with jaxtyping + beartype, please invoke your training script with the environment flag `TYPECHECK=False` - ex. `TYPECHECK=False python train_af3.py`'
310+
model = torch.compile(model)
311+
312+
# model
313+
314+
self.model = model
315+
308316
# optimizer
309317

310318
if not exists(optimizer):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.55"
3+
version = "0.2.57"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)