Skip to content

Commit 9d87f6a

Browse files
committed
default to scaling noise schedule and sampling steps with sigma_data, but keep karras formulation of loss weight
1 parent 8354dc6 commit 9d87f6a

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2444,7 +2444,7 @@ def __init__(
24442444
smooth_lddt_loss_kwargs: dict = dict(),
24452445
weighted_rigid_align_kwargs: dict = dict(),
24462446
centre_random_augmentation_kwargs: dict = dict(),
2447-
karras_formulation = False # use the original EDM formulation from Karras et al. Table 1 in https://arxiv.org/abs/2206.00364 - differences are that the noise and sampling schedules are scaled by sigma data, as well as loss weight adds the sigma data instead of multiply in denominator
2447+
karras_formulation = True, # use the original EDM formulation from Karras et al. Table 1 in https://arxiv.org/abs/2206.00364 - differences are that the noise and sampling schedules are scaled by sigma data, as well as loss weight adds the sigma data instead of multiply in denominator
24482448
):
24492449
super().__init__()
24502450
self.net = net
@@ -2552,8 +2552,7 @@ def sample_schedule(self, num_sample_steps = None):
25522552

25532553
sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
25542554

2555-
scale = 1. if self.karras_formulation else self.sigma_data
2556-
return sigmas * scale
2555+
return sigmas * self.sigma_data
25572556

25582557
@torch.no_grad()
25592558
def sample(
@@ -2634,9 +2633,7 @@ def loss_weight(self, sigma):
26342633
return (sigma ** 2 + self.sigma_data ** 2) * (sigma + self.sigma_data) ** -2
26352634

26362635
def noise_distribution(self, batch_size):
2637-
scale = 1. if self.karras_formulation else self.sigma_data
2638-
2639-
return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() * scale
2636+
return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() * self.sigma_data
26402637

26412638
def forward(
26422639
self,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.120"
3+
version = "0.2.121"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments (0)