66import os
77import random
88import statistics
9+ import traceback
910from collections import defaultdict
1011from collections .abc import Iterable
1112from contextlib import redirect_stderr
5152 rna_constants ,
5253 ligand_constants
5354)
55+
5456from alphafold3_pytorch .common .biomolecule import (
5557 Biomolecule ,
5658 _from_mmcif_object ,
5759 get_residue_constants ,
5860)
61+
5962from alphafold3_pytorch .data import (
6063 mmcif_parsing ,
6164 msa_pairing ,
6265 msa_parsing ,
6366 template_parsing ,
6467)
68+
6569from alphafold3_pytorch .data .data_pipeline import (
6670 FeatureDict ,
6771 get_assembly ,
7074 make_template_features ,
7175 merge_chain_features ,
7276)
77+
7378from alphafold3_pytorch .data .weighted_pdb_sampler import WeightedPDBSampler
79+
7480from alphafold3_pytorch .life import (
7581 ATOM_BONDS ,
7682 ATOMS ,
8288 reverse_complement ,
8389 reverse_complement_tensor ,
8490)
91+
8592from alphafold3_pytorch .utils .data_utils import (
8693 PDB_INPUT_RESIDUE_MOLECULE_TYPE ,
8794 extract_mmcif_metadata_field ,
9198 is_polymer ,
9299 make_one_hot ,
93100)
101+
94102from alphafold3_pytorch .utils .model_utils import (
95103 distance_to_dgram ,
96104 exclusive_cumsum ,
99107 offset_only_positive ,
100108 remove_consecutive_duplicate ,
101109 to_pairwise_mask ,
110+ pack_one
102111)
112+
103113from alphafold3_pytorch .tensor_typing import Bool , Float , Int , typecheck
114+
104115from alphafold3_pytorch .utils .utils import default , exists , first , not_exists
105116
106117from alphafold3_pytorch .attention import (
@@ -759,7 +770,10 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
759770 num_atoms = mol .GetNumAtoms ()
760771 mol_atompair_ids = torch .zeros (num_atoms , num_atoms ).long ()
761772
762- for bond in mol .GetBonds ():
773+ bonds = mol .GetBonds ()
774+ num_bonds = len (bonds )
775+
776+ for bond in has_bonds :
763777 atom_start_index = bond .GetBeginAtomIdx ()
764778 atom_end_index = bond .GetEndAtomIdx ()
765779
@@ -785,12 +799,21 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
785799
786800 updates .extend ([bond_to , bond_from ])
787801
788- coordinates = tensor (coordinates ).long ()
789- updates = tensor (updates ).long ()
802+ if num_bonds > 0 :
803+ coordinates = tensor (coordinates ).long ()
804+ updates = tensor (updates ).long ()
790805
791- mol_atompair_ids = einx .set_at (
792- "[h w], c [2], c -> [h w]" , mol_atompair_ids , coordinates , updates
793- )
806+ # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
807+
808+ molpair_strides = tensor (mol_atompair_ids .stride ())
809+ flattened_coordinates = (coordinates * molpair_strides ).sum (dim = - 1 )
810+
811+ packed_atompair_ids , unpack_one = pack_one (mol_atompair_ids , '*' )
812+ packed_atompair_ids [flattened_coordinates ] = updates
813+
814+ mol_atompair_ids = unpack_one (packed_atompair_ids )
815+
816+ # /einx.set_at
794817
795818 row_col_slice = slice (offset , offset + num_atoms )
796819 atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
@@ -1110,11 +1133,12 @@ def molecule_lengthed_molecule_input_to_atom_input(
11101133
11111134 if mol_is_one_token_per_atom :
11121135 coordinates = []
1113- updates = []
11141136
11151137 has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
1138+ bonds = mol .GetBonds ()
1139+ num_bonds = len (bonds )
11161140
1117- for bond in mol . GetBonds () :
1141+ for bond in bonds :
11181142 atom_start_index = bond .GetBeginAtomIdx ()
11191143 atom_end_index = bond .GetEndAtomIdx ()
11201144
@@ -1125,12 +1149,19 @@ def molecule_lengthed_molecule_input_to_atom_input(
11251149 ]
11261150 )
11271151
1128- updates .extend ([True , True ])
1152+ if num_bonds > 0 :
1153+ coordinates = tensor (coordinates ).long ()
11291154
1130- coordinates = tensor (coordinates ).long ()
1131- updates = tensor (updates ).bool ()
1155+ # has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
11321156
1133- has_bond = einx .set_at ("[h w], c [2], c -> [h w]" , has_bond , coordinates , updates )
1157+ has_bond_stride = tensor (has_bond .stride ())
1158+ flattened_coordinates = (coordinates * has_bond_stride ).sum (dim = - 1 )
1159+ packed_has_bond , unpack_has_bond = pack_one (has_bond , '*' )
1160+
1161+ packed_has_bond [flattened_coordinates ] = True
1162+ has_bond = unpack_has_bond (packed_has_bond , '*' )
1163+
1164+ # / ein.set_at
11341165
11351166 row_col_slice = slice (offset , offset + num_atoms )
11361167 token_bonds [row_col_slice , row_col_slice ] = has_bond
@@ -1279,9 +1310,7 @@ def molecule_lengthed_molecule_input_to_atom_input(
12791310 coordinates = tensor (coordinates ).long ()
12801311 updates = tensor (updates ).long ()
12811312
1282- mol_atompair_ids = einx .set_at (
1283- "[h w], c [2], c -> [h w]" , mol_atompair_ids , coordinates , updates
1284- )
1313+ # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
12851314
12861315 row_col_slice = slice (offset , offset + num_atoms )
12871316 atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
@@ -4638,7 +4667,7 @@ def maybe_transform_to_atom_input(i: Any, raise_exception: bool = False) -> Atom
46384667 try :
46394668 return maybe_to_atom_fn (i )
46404669 except Exception as e :
4641- logger .error (f"Failed to convert input { i } to AtomInput due to: { e } " )
4670+ logger .error (f"Failed to convert input { i } to AtomInput due to: { e } , { traceback . format_exc () } " )
46424671 if raise_exception :
46434672 raise e
46444673 return None
0 commit comments