Skip to content

Commit a35d2cb

Browse files
committed
fix package name so it is actually importable
1 parent 7d4ec22 commit a35d2cb

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,14 @@ docker run -v .:/data --gpus all -it af3
226226
url = {https://api.semanticscholar.org/CorpusID:268063190}
227227
}
228228
```
229+
230+
```bibtex
231+
@article{Puny2021FrameAF,
232+
title = {Frame Averaging for Invariant and Equivariant Network Design},
233+
author = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
234+
journal = {ArXiv},
235+
year = {2021},
236+
volume = {abs/2110.03336},
237+
url = {https://api.semanticscholar.org/CorpusID:238419638}
238+
}
239+
```

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Sequential,
1818
)
1919

20-
from typing import Literal, Tuple, NamedTuple
20+
from typing import Literal, Tuple, NamedTuple, Callable
2121

2222
from alphafold3_pytorch.typing import (
2323
Float,
@@ -37,6 +37,8 @@
3737
full_pairwise_repr_to_windowed
3838
)
3939

40+
from frame_averaging_pytorch import FrameAverage
41+
4042
from taylor_series_linear_attention import TaylorSeriesLinearAttn
4143

4244
import einx
@@ -2106,6 +2108,7 @@ def forward(
21062108
pairwise_trunk: Float['b n n dpt'],
21072109
pairwise_rel_pos_feats: Float['b n n dpr'],
21082110
molecule_atom_lens: Int['b n'],
2111+
frame_average_fn: Callable[[Float['b n 3']], Float['b n 3']] | None = None,
21092112
return_denoised_pos = False,
21102113
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'] | None = None,
21112114
add_smooth_lddt_loss = False,
@@ -2142,6 +2145,13 @@ def forward(
21422145
)
21432146
)
21442147

2148+
# frame average the denoised atom positions if needed
2149+
2150+
if exists(frame_average_fn):
2151+
denoised_atom_pos = frame_average_fn(denoised_atom_pos)
2152+
2153+
# total loss, for accumulating all auxiliary losses
2154+
21452155
total_loss = 0.
21462156

21472157
# if additional molecule feats is provided

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.16"
3+
version = "0.1.17"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -29,7 +29,7 @@ dependencies = [
2929
"einx>=0.2.2",
3030
"ema-pytorch>=0.4.8",
3131
"environs",
32-
"frame-averaging-pytorch",
32+
"frame-averaging-pytorch>=0.0.17",
3333
"hydra-core",
3434
"jaxtyping>=0.2.28",
3535
"lightning>=2.2.5",
@@ -75,4 +75,4 @@ ignore = [
7575
allow-direct-references = true
7676

7777
[tool.hatch.build.targets.wheel]
78-
packages = ["alphafold3-pytorch"]
78+
packages = ["alphafold3_pytorch"]

0 commit comments

Comments
 (0)