
Commit 1a45062

address #124 and #68
1 parent f317060 commit 1a45062

3 files changed: +26 −7 lines changed


alphafold3_pytorch/alphafold3.py

Lines changed: 20 additions & 5 deletions
@@ -2407,7 +2407,8 @@ def __init__(
         S_tmax = 50,
         S_noise = 1.003,
         smooth_lddt_loss_kwargs: dict = dict(),
-        weighted_rigid_align_kwargs: dict = dict()
+        weighted_rigid_align_kwargs: dict = dict(),
+        karras_formulation = False  # use the original EDM formulation from Karras et al. (Table 1 in https://arxiv.org/abs/2206.00364) - the differences are that the noise and sampling schedules are scaled by sigma_data, and the loss weight adds sigma_data in the denominator instead of multiplying by it
     ):
         super().__init__()
         self.net = net
@@ -2440,6 +2441,10 @@ def __init__(
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+        # whether to use the original karras formulation or not
+
+        self.karras_formulation = karras_formulation
+
     @property
     def device(self):
         return next(self.net.parameters()).device
@@ -2504,7 +2509,9 @@ def sample_schedule(self, num_sample_steps = None):
         sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho
 
         sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
-        return sigmas
+
+        scale = 1. if self.karras_formulation else self.sigma_data
+        return sigmas * scale
 
     @torch.no_grad()
     def sample(
@@ -2573,11 +2580,17 @@ def sample(
 
     # training
 
-    def loss_weight(self, sigma):
+    def karras_loss_weight(self, sigma):
         return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2
 
+    def loss_weight(self, sigma):
+        """ in the AlphaFold 3 paper, sigma_data is added in the denominator rather than multiplied as in the original EDM paper """
+        return (sigma ** 2 + self.sigma_data ** 2) * (sigma + self.sigma_data) ** -2
+
     def noise_distribution(self, batch_size):
-        return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()
+        scale = 1. if self.karras_formulation else self.sigma_data
+
+        return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() * scale
 
     def forward(
         self,
@@ -2672,7 +2685,9 @@ def forward(
 
         # regular loss weight as defined in EDM paper
 
-        loss_weights = self.loss_weight(padded_sigmas)
+        loss_weight_fn = self.karras_loss_weight if self.karras_formulation else self.loss_weight
+
+        loss_weights = loss_weight_fn(padded_sigmas)
 
         losses = losses * loss_weights
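The core of the change is the loss weighting. Below is a minimal standalone sketch (not from the commit) contrasting the two formulations that karras_formulation toggles; the sigma_data value is an illustrative assumption, not the module's default, and the function names simply mirror those in the diff.

import torch

sigma_data = 16.  # illustrative value only; the module uses its own self.sigma_data

def karras_loss_weight(sigma):
    # original EDM weighting: sigma_data multiplies sigma in the denominator
    return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2

def loss_weight(sigma):
    # AlphaFold 3 variant: sigma_data is added to sigma in the denominator
    return (sigma ** 2 + sigma_data ** 2) * (sigma + sigma_data) ** -2

sigmas = torch.tensor([0.5, 2., 16., 64.])
print(karras_loss_weight(sigmas))  # grows without bound as sigma -> 0
print(loss_weight(sigmas))         # stays bounded as sigma -> 0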

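The flag also rescales the sampling schedule and the training noise distribution by sigma_data. Here is a standalone sketch of both, following the two hunks above; the hyperparameter values (sigma_min, sigma_max, rho, P_mean, P_std, sigma_data) are illustrative assumptions, not taken from the repository.

import torch
import torch.nn.functional as F

sigma_min, sigma_max, rho = 0.002, 80., 7.   # assumed EDM-style values for illustration
P_mean, P_std = -1.2, 1.2                    # assumed log-normal noise parameters
sigma_data = 16.                             # assumed data scale
karras_formulation = False

def sample_schedule(num_sample_steps = 32):
    N = num_sample_steps
    inv_rho = 1. / rho
    steps = torch.arange(N, dtype = torch.float32)
    sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    sigmas = F.pad(sigmas, (0, 1), value = 0.)  # last step is a sigma of 0
    scale = 1. if karras_formulation else sigma_data  # the commit scales the schedule by sigma_data
    return sigmas * scale

def noise_distribution(batch_size):
    scale = 1. if karras_formulation else sigma_data  # likewise for sampled training noise levels
    return (P_mean + P_std * torch.randn((batch_size,))).exp() * scale

print(sample_schedule(8))
print(noise_distribution(4))
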
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.78"
+version = "0.2.79"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 5 additions & 1 deletion
@@ -285,7 +285,10 @@ def test_sequence_local_attn():
     out = attn(atoms, attn_bias = attn_bias)
     assert out.shape == atoms.shape
 
-def test_diffusion_module():
+@pytest.mark.parametrize('karras_formulation', (True, False))
+def test_diffusion_module(
+    karras_formulation
+):
 
     seq_len = 16
 
@@ -338,6 +341,7 @@ def test_diffusion_module():
 
     edm = ElucidatedAtomDiffusion(
         diffusion_module,
+        karras_formulation = karras_formulation,
         num_sample_steps = 2
     )
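To exercise both code paths locally, the parametrized test can be run on its own (this assumes pytest and the project's test dependencies are installed):

pytest tests/test_af3.py -k test_diffusion_module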
