Skip to content

Commit 492b4b8

Browse files
authored
another einx set_at removal (#282)
* another einx set_at removal * fix an issue with hard validate atom indices
1 parent 6931bdc commit 492b4b8

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def hard_validate_atom_indices_ascending(
264264

265265
# NOTE: this is a relaxed assumption, i.e., that if empty, all -1, or only one molecule, then it passes the test
266266

267-
if present_indices.numel() == 0 or present_indices.shape[-1] <= 1:
267+
if present_indices.numel() == 0 or present_indices.shape[0] <= 1:
268268
continue
269269

270270
difference = einx.subtract(
@@ -1134,7 +1134,6 @@ def molecule_lengthed_molecule_input_to_atom_input(
11341134
if mol_is_one_token_per_atom:
11351135
coordinates = []
11361136

1137-
has_bond = torch.zeros(num_atoms, num_atoms).bool()
11381137
bonds = mol.GetBonds()
11391138
num_bonds = len(bonds)
11401139

@@ -1150,6 +1149,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
11501149
)
11511150

11521151
if num_bonds > 0:
1152+
has_bond = torch.zeros(num_atoms, num_atoms).bool()
1153+
11531154
coordinates = tensor(coordinates).long()
11541155

11551156
# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
@@ -1163,8 +1164,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
11631164

11641165
# / ein.set_at
11651166

1166-
row_col_slice = slice(offset, offset + num_atoms)
1167-
token_bonds[row_col_slice, row_col_slice] = has_bond
1167+
row_col_slice = slice(offset, offset + num_atoms)
1168+
token_bonds[row_col_slice, row_col_slice] = has_bond
11681169

11691170
offset += num_atoms if mol_is_one_token_per_atom else 1
11701171

@@ -3572,13 +3573,14 @@ def pdb_input_to_molecule_input(
35723573
# construct ligand and modified polymer chain token bonds
35733574

35743575
coordinates = []
3575-
updates = []
35763576

35773577
ligand = molecules[ligand_offset]
35783578
num_atoms = ligand.GetNumAtoms()
3579-
has_bond = torch.zeros(num_atoms, num_atoms).bool()
35803579

3581-
for bond in ligand.GetBonds():
3580+
bonds = ligand.GetBonds()
3581+
num_bonds = len(bonds)
3582+
3583+
for bond in bonds:
35823584
atom_start_index = bond.GetBeginAtomIdx()
35833585
atom_end_index = bond.GetEndAtomIdx()
35843586

@@ -3589,15 +3591,24 @@ def pdb_input_to_molecule_input(
35893591
]
35903592
)
35913593

3592-
updates.extend([True, True])
3594+
if num_bonds > 0:
3595+
has_bond = torch.zeros(num_atoms, num_atoms).bool()
35933596

3594-
coordinates = tensor(coordinates).long()
3595-
updates = tensor(updates).bool()
3597+
coordinates = tensor(coordinates).long()
3598+
3599+
# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
3600+
3601+
has_bond_stride = tensor(has_bond.stride())
3602+
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
3603+
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')
3604+
3605+
packed_has_bond[flattened_coordinates] = True
3606+
has_bond = unpack_has_bond(packed_has_bond, '*')
35963607

3597-
has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
3608+
# / einx.set_at
35983609

3599-
row_col_slice = slice(polymer_offset, polymer_offset + num_atoms)
3600-
token_bonds[row_col_slice, row_col_slice] = has_bond
3610+
row_col_slice = slice(polymer_offset, polymer_offset + num_atoms)
3611+
token_bonds[row_col_slice, row_col_slice] = has_bond
36013612

36023613
polymer_offset += num_atoms
36033614
ligand_offset += 1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.35"
3+
version = "0.5.36"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)