@@ -7407,6 +7407,8 @@ def forward(
74077407
74087408 denoised_molecule_pos = denoised_atom_pos .gather (1 , distogram_atom_coords_indices )
74097409
7410+ # get frames atom positions
7411+
74107412 # three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)
74117413 # pred_three_atoms = einx.get_at('b [m] c, b n three -> three b n c', denoised_atom_pos, atom_indices_for_frame)
74127414
@@ -7421,10 +7423,8 @@ def forward(
74217423 three_atoms = three_atom_pos .gather (2 , atom_indices_for_frame )
74227424 pred_three_atoms = three_denoised_atom_pos .gather (2 , atom_indices_for_frame )
74237425
7424- # compute frames
7425-
7426- frames , _ = self .rigid_from_three_points (three_atoms )
7427- pred_frames , _ = self .rigid_from_three_points (pred_three_atoms )
7426+ frame_atoms = rearrange (three_atoms , "three b n c -> b n c three" )
7427+ pred_frame_atoms = rearrange (pred_three_atoms , "three b n c -> b n c three" )
74287428
74297429 # determine mask
74307430 # must be amino acid, nucleotide, or ligand with greater than 0 atoms
@@ -7436,8 +7436,8 @@ def forward(
74367436 align_error = self .compute_alignment_error (
74377437 denoised_molecule_pos ,
74387438 molecule_pos ,
7439- pred_frames ,
7440- frames ,
7439+ pred_frame_atoms , # In the paragraph 2 of section 4.3.2, the Phi_i denotes the coordinates of these frame atoms rather than the rotation matrix.
7440+ frame_atoms ,
74417441 mask = align_error_mask ,
74427442 )
74437443
0 commit comments