Skip to content

Commit b2e4928

Browse files
committed
address missing step scale in Algorithm 18, thanks to @wufandi
1 parent 1a45062 commit b2e4928

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,6 +2406,7 @@ def __init__(
24062406
S_tmin = 0.05,
24072407
S_tmax = 50,
24082408
S_noise = 1.003,
2409+
step_scale = 1.5,
24092410
smooth_lddt_loss_kwargs: dict = dict(),
24102411
weighted_rigid_align_kwargs: dict = dict(),
24112412
karras_formulation = False # use the original EDM formulation from Karras et al. Table 1 in https://arxiv.org/abs/2206.00364 - differences are that the noise and sampling schedules are scaled by sigma data, as well as loss weight adds the sigma data instead of multiply in denominator
@@ -2425,6 +2426,7 @@ def __init__(
24252426
self.P_std = P_std
24262427

24272428
self.num_sample_steps = num_sample_steps # otherwise known as N in the paper
2429+
self.step_scale = step_scale
24282430

24292431
self.S_churn = S_churn
24302432
self.S_tmin = S_tmin
@@ -2523,7 +2525,7 @@ def sample(
25232525
tqdm_pbar_title = 'sampling time step',
25242526
**network_condition_kwargs
25252527
):
2526-
num_sample_steps = default(num_sample_steps, self.num_sample_steps)
2528+
step_scale, num_sample_steps = self.step_scale, default(num_sample_steps, self.num_sample_steps)
25272529

25282530
shape = (*atom_mask.shape, 3)
25292531

@@ -2562,14 +2564,14 @@ def sample(
25622564
model_output = self.preconditioned_network_forward(atom_pos_hat, sigma_hat, clamp = clamp, network_condition_kwargs = network_condition_kwargs)
25632565
denoised_over_sigma = (atom_pos_hat - model_output) / sigma_hat
25642566

2565-
atom_pos_next = atom_pos_hat + (sigma_next - sigma_hat) * denoised_over_sigma
2567+
atom_pos_next = atom_pos_hat + (sigma_next - sigma_hat) * denoised_over_sigma * step_scale
25662568

25672569
# second order correction, if not the last timestep
25682570

25692571
if sigma_next != 0:
25702572
model_output_next = self.preconditioned_network_forward(atom_pos_next, sigma_next, clamp = clamp, network_condition_kwargs = network_condition_kwargs)
25712573
denoised_prime_over_sigma = (atom_pos_next - model_output_next) / sigma_next
2572-
atom_pos_next = atom_pos_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
2574+
atom_pos_next = atom_pos_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) * step_scale
25732575

25742576
atom_pos = atom_pos_next
25752577

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.79"
3+
version = "0.2.80"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)