Skip to content

Commit be1ea03

Browse files
committed
einx indexing has performance issues
1 parent e99cfb4 commit be1ea03

File tree

4 files changed

+50
-28
lines changed

4 files changed

+50
-28
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@
104104
ExpressCoordinatesInFrame,
105105
RigidFrom3Points,
106106
RigidFromReference3Points,
107-
calculate_weighted_rigid_align_weights
107+
calculate_weighted_rigid_align_weights,
108+
pack_one
108109
)
109110

110111
from alphafold3_pytorch.utils.model_utils import distance_to_dgram
@@ -250,15 +251,6 @@ def max_neg_value(t: Tensor):
250251
def dict_to_device(d, device):
251252
return tree_map(lambda t: t.to(device) if is_tensor(t) else t, d)
252253

253-
def pack_one(t, pattern):
254-
packed, ps = pack([t], pattern)
255-
256-
def unpack_one(to_unpack, unpack_pattern = None):
257-
unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
258-
return unpacked
259-
260-
return packed, unpack_one
261-
262254
def exclusive_cumsum(t, dim = -1):
263255
return t.cumsum(dim = dim) - t
264256

alphafold3_pytorch/inputs.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import random
88
import statistics
9+
import traceback
910
from collections import defaultdict
1011
from collections.abc import Iterable
1112
from contextlib import redirect_stderr
@@ -51,17 +52,20 @@
5152
rna_constants,
5253
ligand_constants
5354
)
55+
5456
from alphafold3_pytorch.common.biomolecule import (
5557
Biomolecule,
5658
_from_mmcif_object,
5759
get_residue_constants,
5860
)
61+
5962
from alphafold3_pytorch.data import (
6063
mmcif_parsing,
6164
msa_pairing,
6265
msa_parsing,
6366
template_parsing,
6467
)
68+
6569
from alphafold3_pytorch.data.data_pipeline import (
6670
FeatureDict,
6771
get_assembly,
@@ -70,7 +74,9 @@
7074
make_template_features,
7175
merge_chain_features,
7276
)
77+
7378
from alphafold3_pytorch.data.weighted_pdb_sampler import WeightedPDBSampler
79+
7480
from alphafold3_pytorch.life import (
7581
ATOM_BONDS,
7682
ATOMS,
@@ -82,6 +88,7 @@
8288
reverse_complement,
8389
reverse_complement_tensor,
8490
)
91+
8592
from alphafold3_pytorch.utils.data_utils import (
8693
PDB_INPUT_RESIDUE_MOLECULE_TYPE,
8794
extract_mmcif_metadata_field,
@@ -91,6 +98,7 @@
9198
is_polymer,
9299
make_one_hot,
93100
)
101+
94102
from alphafold3_pytorch.utils.model_utils import (
95103
distance_to_dgram,
96104
exclusive_cumsum,
@@ -99,8 +107,11 @@
99107
offset_only_positive,
100108
remove_consecutive_duplicate,
101109
to_pairwise_mask,
110+
pack_one
102111
)
112+
103113
from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck
114+
104115
from alphafold3_pytorch.utils.utils import default, exists, first, not_exists
105116

106117
from alphafold3_pytorch.attention import (
@@ -759,7 +770,10 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
759770
num_atoms = mol.GetNumAtoms()
760771
mol_atompair_ids = torch.zeros(num_atoms, num_atoms).long()
761772

762-
for bond in mol.GetBonds():
773+
bonds = mol.GetBonds()
774+
num_bonds = len(bonds)
775+
776+
for bond in bonds:
763777
atom_start_index = bond.GetBeginAtomIdx()
764778
atom_end_index = bond.GetEndAtomIdx()
765779

@@ -785,12 +799,21 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
785799

786800
updates.extend([bond_to, bond_from])
787801

788-
coordinates = tensor(coordinates).long()
789-
updates = tensor(updates).long()
802+
if num_bonds > 0:
803+
coordinates = tensor(coordinates).long()
804+
updates = tensor(updates).long()
790805

791-
mol_atompair_ids = einx.set_at(
792-
"[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates
793-
)
806+
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
807+
808+
molpair_strides = tensor(mol_atompair_ids.stride())
809+
flattened_coordinates = (coordinates * molpair_strides).sum(dim = -1)
810+
811+
packed_atompair_ids, unpack_one = pack_one(mol_atompair_ids, '*')
812+
packed_atompair_ids[flattened_coordinates] = updates
813+
814+
mol_atompair_ids = unpack_one(packed_atompair_ids)
815+
816+
# /einx.set_at
794817

795818
row_col_slice = slice(offset, offset + num_atoms)
796819
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
@@ -1110,11 +1133,12 @@ def molecule_lengthed_molecule_input_to_atom_input(
11101133

11111134
if mol_is_one_token_per_atom:
11121135
coordinates = []
1113-
updates = []
11141136

11151137
has_bond = torch.zeros(num_atoms, num_atoms).bool()
1138+
bonds = mol.GetBonds()
1139+
num_bonds = len(bonds)
11161140

1117-
for bond in mol.GetBonds():
1141+
for bond in bonds:
11181142
atom_start_index = bond.GetBeginAtomIdx()
11191143
atom_end_index = bond.GetEndAtomIdx()
11201144

@@ -1125,12 +1149,19 @@ def molecule_lengthed_molecule_input_to_atom_input(
11251149
]
11261150
)
11271151

1128-
updates.extend([True, True])
1152+
if num_bonds > 0:
1153+
coordinates = tensor(coordinates).long()
11291154

1130-
coordinates = tensor(coordinates).long()
1131-
updates = tensor(updates).bool()
1155+
# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
11321156

1133-
has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
1157+
has_bond_stride = tensor(has_bond.stride())
1158+
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
1159+
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')
1160+
1161+
packed_has_bond[flattened_coordinates] = True
1162+
has_bond = unpack_has_bond(packed_has_bond, '*')
1163+
1164+
# /einx.set_at
11341165

11351166
row_col_slice = slice(offset, offset + num_atoms)
11361167
token_bonds[row_col_slice, row_col_slice] = has_bond
@@ -1279,9 +1310,7 @@ def molecule_lengthed_molecule_input_to_atom_input(
12791310
coordinates = tensor(coordinates).long()
12801311
updates = tensor(updates).long()
12811312

1282-
mol_atompair_ids = einx.set_at(
1283-
"[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates
1284-
)
1313+
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
12851314

12861315
row_col_slice = slice(offset, offset + num_atoms)
12871316
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
@@ -4638,7 +4667,7 @@ def maybe_transform_to_atom_input(i: Any, raise_exception: bool = False) -> Atom
46384667
try:
46394668
return maybe_to_atom_fn(i)
46404669
except Exception as e:
4641-
logger.error(f"Failed to convert input {i} to AtomInput due to: {e}")
4670+
logger.error(f"Failed to convert input {i} to AtomInput due to: {e}, {traceback.format_exc()}")
46424671
if raise_exception:
46434672
raise e
46444673
return None

alphafold3_pytorch/utils/model_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import einx
55
import torch
66
import torch.nn.functional as F
7-
from einops import einsum, pack, rearrange, reduce, repeat, unpack
7+
from einops import einsum, pack, unpack, rearrange, reduce, repeat
88
from torch import Tensor
99
from torch.nn import Module
1010

@@ -17,6 +17,7 @@
1717

1818
# helper functions
1919

20+
2021
# default scheduler used in paper w/ warmup
2122

2223

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.32"
3+
version = "0.5.34"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)