1212from typing import Any , Callable , List , Literal , Set , Tuple , Type
1313
1414import einx
15+ from einops import pack
1516
1617import numpy as np
1718from numpy .lib .format import open_memmap
5455 reverse_complement_tensor ,
5556)
5657
57-
5858from alphafold3_pytorch .tensor_typing import Bool , Float , Int , typecheck
5959from alphafold3_pytorch .utils .data_utils import RESIDUE_MOLECULE_TYPE , get_residue_molecule_type
6060from alphafold3_pytorch .utils .model_utils import exclusive_cumsum
@@ -298,6 +298,51 @@ def __getitem__(self, idx: int) -> AtomInput:
298298 file = self .files [idx ]
299299 return file_to_atom_input (file )
300300
301+ # atom reference position to atompair inputs
302+ # will be used in the `default_extract_atompair_feats_fn` below in MoleculeInput
303+
@typecheck
def atom_ref_pos_to_atompair_inputs(
    atom_ref_pos: Float['m 3'],
    atom_ref_space_uid: Int['m'] | None = None,
) -> Float['m m 5']:
    """Turn atom reference positions into raw per-atom-pair input features.

    For every ordered atom pair (i, j) the last dimension holds, in order:

    * the 3 components of ``atom_ref_pos[i] - atom_ref_pos[j]``
    * ``1 / (1 + ||offset||^2)``, the inverse squared distance
    * a 0./1. flag that is 1. when atoms i and j have equal
      ``atom_ref_space_uid``

    All 5 features of a pair are set to 0. when the pair's flag is false.
    When ``atom_ref_space_uid`` is not given, every pair is treated as being
    in the same reference space, so nothing is zeroed out.

    Follows Algorithm 5, lines 2-6 (per the original inline comments); the
    result is meant to be projected into atompair feats within Alphafold3.
    """

    # Algorithm 5, line 2 - offset between every pair of atoms

    rel_offsets = einx.subtract('i c, j c -> i j c', atom_ref_pos, atom_ref_pos)

    # Algorithm 5, line 5 - inverse of (1 + squared euclidean distance)

    sq_dist = rel_offsets.norm(p = 2, dim = -1).pow(2)
    inv_sq_dist = (1 + sq_dist).pow(-1)

    # Algorithm 5, line 3 - which pairs live in the same reference space

    if not exists(atom_ref_space_uid):
        # no uids supplied -> all pairs count as sharing one reference space
        shares_ref_space = torch.ones_like(inv_sq_dist).bool()
    else:
        shares_ref_space = einx.equal('i, j -> i j', atom_ref_space_uid, atom_ref_space_uid)

    # stack the three feature groups along a new trailing dimension (3 + 1 + 1 = 5)

    feature_groups = (
        rel_offsets,
        inv_sq_dist,
        shares_ref_space.float(),
    )

    packed_feats, _ = pack(feature_groups, 'i j *')

    # zero every feature of pairs that are not in the same reference space

    return einx.where(
        'i j, i j dapi, -> i j dapi',
        shares_ref_space, packed_feats, 0.
    )
345+
301346# molecule input - accepting list of molecules as rdchem.Mol + the atomic lengths for how to pool into tokens
302347
303348def default_extract_atom_feats_fn (atom : Atom ):
@@ -316,8 +361,7 @@ def default_extract_atompair_feats_fn(mol: Mol):
316361
317362 all_atom_pos_tensor = tensor (all_atom_pos )
318363
319- dist_matrix = torch .cdist (all_atom_pos_tensor , all_atom_pos_tensor )
320- return torch .stack ((dist_matrix ,), dim = - 1 )
364+ return atom_ref_pos_to_atompair_inputs (all_atom_pos_tensor ) # what they did in the paper, but can be overwritten
321365
322366@typecheck
323367@dataclass
0 commit comments