Skip to content

Commit b881335

Browse files
committed
address #124
1 parent a70c6d3 commit b881335

File tree

3 files changed

+12
-14
lines changed

3 files changed

+12
-14
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,11 +2524,9 @@ def preconditioned_network_forward(
25242524

25252525
padded_sigma = rearrange(sigma, 'b -> b 1 1')
25262526

2527-
maybe_c_noise = self.c_noise if self.karras_formulation else identity
2528-
25292527
net_out = self.net(
25302528
self.c_in(padded_sigma) * noised_atom_pos,
2531-
times = maybe_c_noise(sigma),
2529+
times = sigma,
25322530
**network_condition_kwargs
25332531
)
25342532

@@ -2615,7 +2613,7 @@ def sample(
26152613

26162614
# second order correction, if not the last timestep
26172615

2618-
if sigma_next != 0:
2616+
if self.karras_formulation and sigma_next != 0:
26192617
model_output_next = self.preconditioned_network_forward(atom_pos_next, sigma_next, clamp = clamp, network_condition_kwargs = network_condition_kwargs)
26202618
denoised_prime_over_sigma = (atom_pos_next - model_output_next) / sigma_next
26212619
atom_pos_next = atom_pos_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) * step_scale
@@ -2674,7 +2672,9 @@ def forward(
26742672

26752673
noise = torch.randn_like(atom_pos_ground_truth)
26762674

2677-
noised_atom_pos = atom_pos_ground_truth + padded_sigmas * noise # alphas are 1. in the paper
2675+
maybe_c_noise = self.c_noise if not self.karras_formulation else identity # @wufandi claims the paper has a bug here https://github.com/lucidrains/alphafold3-pytorch/issues/124#issuecomment-2268374756
2676+
2677+
noised_atom_pos = atom_pos_ground_truth + padded_sigmas * maybe_c_noise(noise) # alphas are 1. in the paper
26782678

26792679
denoised_atom_pos = self.preconditioned_network_forward(
26802680
noised_atom_pos,
@@ -3941,20 +3941,20 @@ def get_cid_molecule_type(
39413941

39423942
return molecule_type
39433943

3944+
@typecheck
39443945
def _protein_structure_from_feature(
39453946
asym_id: Int[' n'],
39463947
molecule_ids: Int[' n'],
39473948
molecule_atom_lens: Int[' n'],
39483949
atom_pos: Float[' m 3'],
3949-
atom_mask: Bool[' m'],):
3950+
atom_mask: Bool[' m'],
3951+
) -> Bio.PDB.Structure.Structure:
3952+
39503953
"""
3951-
39523954
create structure for unresolved protein
39533955
39543956
atom_mask: True for valid atom, False for missing/padding atom
3955-
return: Bio.PDB.Structure.Structure
39563957
"""
3957-
39583958
num_atom = atom_pos.shape[0]
39593959
num_res = molecule_ids.shape[0]
39603960

@@ -4326,17 +4326,14 @@ def _compute_unresolved_rasa(
43264326
chain_molecule_ids = molecule_ids[chain_mask]
43274327
chain_molecule_atom_lens = molecule_atom_lens[chain_mask]
43284328

4329-
chain_mask_to_atom = repeat_consecutive_with_lens(
4330-
chain_mask.unsqueeze(0), molecule_atom_lens.unsqueeze(0)
4331-
).squeeze(0)
4329+
chain_mask_to_atom = torch.repeat_interleave(chain_mask, molecule_atom_lens)
43324330

43334331
# if there's padding in num atom
43344332
num_pad = num_atom - molecule_atom_lens.sum()
43354333
if num_pad > 0:
43364334
chain_mask_to_atom = F.pad(
43374335
chain_mask_to_atom, (0, num_pad), value = False)
43384336

4339-
43404337
chain_atom_pos = atom_pos[chain_mask_to_atom]
43414338
chain_atom_mask = atom_mask[chain_mask_to_atom]
43424339

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.97"
3+
version = "0.2.100"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,7 @@ def test_unresolved_protein_rasa():
11371137
unresolved_residue_mask = torch.randint(0, 2, asym_id.shape).bool()
11381138

11391139
compute_model_selection_score = ComputeModelSelectionScore()
1140+
11401141
unresolved_rasa = compute_model_selection_score.compute_unresolved_rasa(
11411142
unresolved_cid=[1],
11421143
unresolved_residue_mask=unresolved_residue_mask,

0 commit comments

Comments
 (0)