Skip to content

Commit 41c58c3

Browse files
committed
line 3 in Algorithm 18
1 parent b2e4928 commit 41c58c3

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2407,8 +2407,10 @@ def __init__(
24072407
S_tmax = 50,
24082408
S_noise = 1.003,
24092409
step_scale = 1.5,
2410+
augment_during_sampling = True,
24102411
smooth_lddt_loss_kwargs: dict = dict(),
24112412
weighted_rigid_align_kwargs: dict = dict(),
2413+
centre_random_augmentation_kwargs: dict = dict(),
24122414
karras_formulation = False # use the original EDM formulation from Karras et al. Table 1 in https://arxiv.org/abs/2206.00364 - differences are that the noise and sampling schedules are scaled by sigma data, as well as loss weight adds the sigma data instead of multiply in denominator
24132415
):
24142416
super().__init__()
@@ -2433,6 +2435,11 @@ def __init__(
24332435
self.S_tmax = S_tmax
24342436
self.S_noise = S_noise
24352437

2438+
# centre random augmenter
2439+
2440+
self.augment_during_sampling = augment_during_sampling
2441+
self.centre_random_augmenter = CentreRandomAugmentation(**centre_random_augmentation_kwargs)
2442+
24362443
# weighted rigid align
24372444

24382445
self.weighted_rigid_align = WeightedRigidAlign(**weighted_rigid_align_kwargs)
@@ -2553,8 +2560,12 @@ def sample(
25532560

25542561
maybe_tqdm_wrapper = tqdm if use_tqdm_pbar else identity
25552562

2563+
maybe_augment_fn = self.centre_random_augmenter if self.augment_during_sampling else identity
2564+
25562565
for sigma, sigma_next, gamma in maybe_tqdm_wrapper(sigmas_and_gammas, desc = tqdm_pbar_title):
2557-
sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
2566+
sigma, sigma_next, gamma = tuple(t.item() for t in (sigma, sigma_next, gamma))
2567+
2568+
atom_pos = maybe_augment_fn(atom_pos)
25582569

25592570
eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
25602571

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.80"
3+
version = "0.2.81"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)