@@ -264,7 +264,7 @@ def hard_validate_atom_indices_ascending(
264264
265265 # NOTE: this is a relaxed assumption, i.e., if the indices are empty, all -1, or there is only one molecule, then the test passes
266266
267- if present_indices .numel () == 0 or present_indices .shape [- 1 ] <= 1 :
267+ if present_indices .numel () == 0 or present_indices .shape [0 ] <= 1 :
268268 continue
269269
270270 difference = einx .subtract (
@@ -1134,7 +1134,6 @@ def molecule_lengthed_molecule_input_to_atom_input(
11341134 if mol_is_one_token_per_atom :
11351135 coordinates = []
11361136
1137- has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
11381137 bonds = mol .GetBonds ()
11391138 num_bonds = len (bonds )
11401139
@@ -1150,6 +1149,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
11501149 )
11511150
11521151 if num_bonds > 0 :
1152+ has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
1153+
11531154 coordinates = tensor (coordinates ).long ()
11541155
11551156 # has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
@@ -1163,8 +1164,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
11631164
11641165 # / einx.set_at
11651166
1166- row_col_slice = slice (offset , offset + num_atoms )
1167- token_bonds [row_col_slice , row_col_slice ] = has_bond
1167+ row_col_slice = slice (offset , offset + num_atoms )
1168+ token_bonds [row_col_slice , row_col_slice ] = has_bond
11681169
11691170 offset += num_atoms if mol_is_one_token_per_atom else 1
11701171
@@ -3572,13 +3573,14 @@ def pdb_input_to_molecule_input(
35723573 # construct ligand and modified polymer chain token bonds
35733574
35743575 coordinates = []
3575- updates = []
35763576
35773577 ligand = molecules [ligand_offset ]
35783578 num_atoms = ligand .GetNumAtoms ()
3579- has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
35803579
3581- for bond in ligand .GetBonds ():
3580+ bonds = ligand .GetBonds ()
3581+ num_bonds = len (bonds )
3582+
3583+ for bond in bonds :
35823584 atom_start_index = bond .GetBeginAtomIdx ()
35833585 atom_end_index = bond .GetEndAtomIdx ()
35843586
@@ -3589,15 +3591,24 @@ def pdb_input_to_molecule_input(
35893591 ]
35903592 )
35913593
3592- updates .extend ([True , True ])
3594+ if num_bonds > 0 :
3595+ has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
35933596
3594- coordinates = tensor (coordinates ).long ()
3595- updates = tensor (updates ).bool ()
3597+ coordinates = tensor (coordinates ).long ()
3598+
3599+ # has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
3600+
3601+ has_bond_stride = tensor (has_bond .stride ())
3602+ flattened_coordinates = (coordinates * has_bond_stride ).sum (dim = - 1 )
3603+ packed_has_bond , unpack_has_bond = pack_one (has_bond , '*' )
3604+
3605+ packed_has_bond [flattened_coordinates ] = True
3606+ has_bond = unpack_has_bond (packed_has_bond , '*' )
35963607
3597- has_bond = einx . set_at ( "[h w], c [2], c -> [h w]" , has_bond , coordinates , updates )
3608+ # / einx.set_at
35983609
3599- row_col_slice = slice (polymer_offset , polymer_offset + num_atoms )
3600- token_bonds [row_col_slice , row_col_slice ] = has_bond
3610+ row_col_slice = slice (polymer_offset , polymer_offset + num_atoms )
3611+ token_bonds [row_col_slice , row_col_slice ] = has_bond
36013612
36023613 polymer_offset += num_atoms
36033614 ligand_offset += 1
0 commit comments