Skip to content

Commit 3231cab

Browse files
committed
test the waters with stochastic frame averaging first before going any deeper
1 parent a35d2cb commit 3231cab

File tree

4 files changed

+58
-11
lines changed

4 files changed

+58
-11
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -237,3 +237,14 @@ docker run -v .:/data --gpus all -it af3
237237
url = {https://api.semanticscholar.org/CorpusID:238419638}
238238
}
239239
```
240+
241+
```bibtex
242+
@article{Duval2023FAENetFA,
243+
title = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
244+
author = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
245+
journal = {ArXiv},
246+
year = {2023},
247+
volume = {abs/2305.05577},
248+
url = {https://api.semanticscholar.org/CorpusID:258564608}
249+
}
250+
```

alphafold3_pytorch/alphafold3.py

Lines changed: 32 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -2108,7 +2108,6 @@ def forward(
21082108
pairwise_trunk: Float['b n n dpt'],
21092109
pairwise_rel_pos_feats: Float['b n n dpr'],
21102110
molecule_atom_lens: Int['b n'],
2111-
frame_average_fn: Callable[[Float['b n 3']], Float['b n 3']] | None = None,
21122111
return_denoised_pos = False,
21132112
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'] | None = None,
21142113
add_smooth_lddt_loss = False,
@@ -2145,11 +2144,6 @@ def forward(
21452144
)
21462145
)
21472146

2148-
# frame average the denoised atom positions if needed
2149-
2150-
if exists(frame_average_fn):
2151-
denoised_atom_pos = frame_average_fn(denoised_atom_pos)
2152-
21532147
# total loss, for accumulating all auxiliary losses
21542148

21552149
total_loss = 0.
@@ -2905,7 +2899,8 @@ def __init__(
29052899
S_tmax = 50,
29062900
S_noise = 1.003,
29072901
),
2908-
augment_kwargs: dict = dict()
2902+
augment_kwargs: dict = dict(),
2903+
stochastic_frame_average = False
29092904
):
29102905
super().__init__()
29112906

@@ -2918,6 +2913,18 @@ def __init__(
29182913
self.num_augmentations = diffusion_num_augmentations
29192914
self.augmenter = CentreRandomAugmentation(**augment_kwargs)
29202915

2916+
# stochastic frame averaging
2917+
# https://arxiv.org/abs/2305.05577
2918+
2919+
self.stochastic_frame_average = stochastic_frame_average
2920+
2921+
if stochastic_frame_average:
2922+
self.frame_average = FrameAverage(
2923+
dim = 3,
2924+
stochastic = True,
2925+
return_stochastic_as_augmented_pos = True
2926+
)
2927+
29212928
# input feature embedder
29222929

29232930
self.input_embedder = InputFeatureEmbedder(
@@ -3330,7 +3337,7 @@ def forward(
33303337

33313338
if calc_diffusion_loss:
33323339

3333-
num_augs = self.num_augmentations
3340+
num_augs = self.num_augmentations + int(self.stochastic_frame_average)
33343341

33353342
# take care of augmentation
33363343
# they did 48 during training, as the trunk did the heavy lifting
@@ -3378,8 +3385,25 @@ def forward(
33783385
)
33793386
)
33803387

3388+
# handle stochastic frame averaging
3389+
3390+
if self.stochastic_frame_average:
3391+
fa_atom_pos, atom_pos = atom_pos[:1], atom_pos[1:]
3392+
3393+
fa_atom_pos = self.frame_average(
3394+
fa_atom_pos,
3395+
frame_average_mask = atom_mask[:1]
3396+
)
3397+
3398+
# normal random augmentations, 48 times in paper
3399+
33813400
atom_pos = self.augmenter(atom_pos)
33823401

3402+
# concat back the stochastic frame averaged position
3403+
3404+
if self.stochastic_frame_average:
3405+
atom_pos = torch.cat((fa_atom_pos, atom_pos), dim = 0)
3406+
33833407
diffusion_loss, denoised_atom_pos, diffusion_loss_breakdown, _ = self.edm(
33843408
atom_pos,
33853409
additional_molecule_feats = additional_molecule_feats,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"einx>=0.2.2",
3030
"ema-pytorch>=0.4.8",
3131
"environs",
32-
"frame-averaging-pytorch>=0.0.17",
32+
"frame-averaging-pytorch>=0.0.18",
3333
"hydra-core",
3434
"jaxtyping>=0.2.28",
3535
"lightning>=2.2.5",

tests/test_af3.py

Lines changed: 14 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -31,6 +31,9 @@
3131
atom_ref_pos_to_atompair_inputs
3232
)
3333

34+
def join(str, delimiter = ','):
35+
return delimiter.join(str)
36+
3437
def test_atom_ref_pos_to_atompair_inputs():
3538
atom_ref_pos = torch.randn(16, 3)
3639
atom_ref_space_uid = torch.ones(16).long()
@@ -400,9 +403,17 @@ def test_distogram_head():
400403

401404
logits = distogram_head(pairwise_repr)
402405

403-
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
406+
@pytest.mark.parametrize(
407+
join([
408+
'window_atompair_inputs',
409+
'stochastic_frame_average'
410+
]), [
411+
(True, False),
412+
(True, False)
413+
])
404414
def test_alphafold3(
405-
window_atompair_inputs: bool
415+
window_atompair_inputs: bool,
416+
stochastic_frame_average: bool
406417
):
407418
seq_len = 16
408419
atoms_per_window = 27
@@ -457,6 +468,7 @@ def test_alphafold3(
457468
token_transformer_depth = 1,
458469
atom_decoder_depth = 1,
459470
),
471+
stochastic_frame_average = stochastic_frame_average
460472
)
461473

462474
loss, breakdown = alphafold3(

0 commit comments

Comments (0)