@@ -2524,11 +2524,9 @@ def preconditioned_network_forward(
25242524
25252525 padded_sigma = rearrange (sigma , 'b -> b 1 1' )
25262526
2527- maybe_c_noise = self .c_noise if self .karras_formulation else identity
2528-
25292527 net_out = self .net (
25302528 self .c_in (padded_sigma ) * noised_atom_pos ,
2531- times = maybe_c_noise ( sigma ) ,
2529+ times = sigma ,
25322530 ** network_condition_kwargs
25332531 )
25342532
@@ -2615,7 +2613,7 @@ def sample(
26152613
26162614 # second order correction, if not the last timestep
26172615
2618- if sigma_next != 0 :
2616+ if self . karras_formulation and sigma_next != 0 :
26192617 model_output_next = self .preconditioned_network_forward (atom_pos_next , sigma_next , clamp = clamp , network_condition_kwargs = network_condition_kwargs )
26202618 denoised_prime_over_sigma = (atom_pos_next - model_output_next ) / sigma_next
26212619 atom_pos_next = atom_pos_hat + 0.5 * (sigma_next - sigma_hat ) * (denoised_over_sigma + denoised_prime_over_sigma ) * step_scale
@@ -2674,7 +2672,9 @@ def forward(
26742672
26752673 noise = torch .randn_like (atom_pos_ground_truth )
26762674
2677- noised_atom_pos = atom_pos_ground_truth + padded_sigmas * noise # alphas are 1. in the paper
2675+ maybe_c_noise = self .c_noise if not self .karras_formulation else identity # @wufandi claims the paper has a bug here https://github.com/lucidrains/alphafold3-pytorch/issues/124#issuecomment-2268374756
2676+
2677+ noised_atom_pos = atom_pos_ground_truth + padded_sigmas * maybe_c_noise (noise ) # alphas are 1. in the paper
26782678
26792679 denoised_atom_pos = self .preconditioned_network_forward (
26802680 noised_atom_pos ,
@@ -3941,20 +3941,20 @@ def get_cid_molecule_type(
39413941
39423942 return molecule_type
39433943
3944+ @typecheck
39443945def _protein_structure_from_feature (
39453946 asym_id : Int [' n' ],
39463947 molecule_ids : Int [' n' ],
39473948 molecule_atom_lens : Int [' n' ],
39483949 atom_pos : Float [' m 3' ],
3949- atom_mask : Bool [' m' ],):
3950+ atom_mask : Bool [' m' ],
3951+ ) -> Bio .PDB .Structure .Structure :
3952+
39503953 """
3951-
39523954 create structure for unresolved protein
39533955
39543956 atom_mask: True for valid atom, False for missing/padding atom
3955- return: Bio.PDB.Structure.Structure
39563957 """
3957-
39583958 num_atom = atom_pos .shape [0 ]
39593959 num_res = molecule_ids .shape [0 ]
39603960
@@ -4326,17 +4326,14 @@ def _compute_unresolved_rasa(
43264326 chain_molecule_ids = molecule_ids [chain_mask ]
43274327 chain_molecule_atom_lens = molecule_atom_lens [chain_mask ]
43284328
4329- chain_mask_to_atom = repeat_consecutive_with_lens (
4330- chain_mask .unsqueeze (0 ), molecule_atom_lens .unsqueeze (0 )
4331- ).squeeze (0 )
4329+ chain_mask_to_atom = torch .repeat_interleave (chain_mask , molecule_atom_lens )
43324330
43334331 # if there's padding in num atom
43344332 num_pad = num_atom - molecule_atom_lens .sum ()
43354333 if num_pad > 0 :
43364334 chain_mask_to_atom = F .pad (
43374335 chain_mask_to_atom , (0 , num_pad ), value = False )
43384336
4339-
43404337 chain_atom_pos = atom_pos [chain_mask_to_atom ]
43414338 chain_atom_mask = atom_mask [chain_mask_to_atom ]
43424339
0 commit comments